"""Training and evaluation engine loops for CISEN and baseline models.

NOTE(review): reconstructed from a whitespace-mangled diff; formatting was
restored and only the defects called out in inline BUGFIX notes changed.
"""

import os
import time
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
from torch.autograd import Variable
from loguru import logger
from utils.dataset import tokenize
from utils.misc import AverageMeter, ProgressMeter, trainMetric, compute_AP, compute_mAP, \
    adjust_learning_rate, compute_acc, compute_F1, compute_ACG, token_wise_similarity
from utils.hash import CalcSim, Calcloss, CalcNTXentLoss, GenerateCode, CalcMap, CalcTopMap
from torch.cuda.amp import autocast


def train_one_epoch(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one supervised classification epoch under AMP.

    Forwards (image, text) batches, optimizes a cross-entropy loss with
    gradient scaling/clipping, logs running meters every ``args.print_freq``
    batches and the epoch precision at the end.
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    prec_meter = AverageMeter('Precision', '2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch + 1, args.epochs))

    precisions = []  # renamed from ``eval``, which shadowed the builtin

    model.train()
    time.sleep(2)
    end = time.time()
    for i, (idx, image, text, label, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # move the batch to the GPU
        image = torch.tensor(image).cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward pass in mixed precision; loss in full precision
        with amp.autocast():
            score = model(image, text)
        loss = F.cross_entropy(score, target)

        # BUGFIX: removed a stray argument-less ``F.multi_head_attention_forward()``
        # call that raised a TypeError on every iteration.

        # backward + optimizer step with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        scaler.step(optimizer)
        scaler.update()

        # running metrics
        prec = trainMetric(score, label)
        precisions.append(prec.detach().cpu())

        loss_meter.update(loss.item(), image.size(0))
        prec_meter.update(prec.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)

    logger.info('precision:{:.2f}%'.format(np.sum(precisions) / len(train_loader.dataset) * 100.))


"""
    supervised
"""
def train_hash_one_epoch(train_loader, model, optimizer, scheduler, scaler, epoch, H, Hi, Ht, train_labels, args):
    """Run one supervised hashing epoch.

    Writes each batch's fused/image/text hash codes into the global buffers
    ``H``/``Hi``/``Ht`` in place and optimizes a weighted sum of the pairwise
    hashing loss and the semantic cross-entropy loss.

    Returns the updated (H, Hi, Ht).
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    prec_meter = AverageMeter('Precision', '2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch + 1, args.epochs))

    precisions = []

    model.train()
    time.sleep(2)
    # hoisted loop invariants (were recomputed every batch)
    train_labels = train_labels.cuda(non_blocking=True)
    num_train = len(train_loader.dataset)
    end = time.time()
    for i, (idx, image, text, label, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        image = torch.tensor(image).cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        with amp.autocast():
            img_hash, txt_hash, out_hash, score = model(image, text)
            S = CalcSim(target, train_labels)

        # refresh the global code buffers for this batch's sample indices
        H[idx, :] = out_hash.float()
        Hi[idx, :] = img_hash.float()
        Ht[idx, :] = txt_hash.float()
        Bbatch = torch.sign(out_hash)

        num_batch = len(target)
        # weighted hashing + semantic loss
        loss_semantic = F.cross_entropy(score, target)
        _, _, loss_p = Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args)
        loss = (1 - args.eta) * loss_p + args.eta * loss_semantic

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        scaler.step(optimizer)
        scaler.update()

        # metric
        prec = trainMetric(score, label)
        precisions.append(prec.detach().cpu())

        loss_meter.update(loss.item(), image.size(0))
        prec_meter.update(prec.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)

    logger.info('Classification precision:{:.2f}%'.format(np.sum(precisions) / len(train_loader.dataset) * 100.))

    return H, Hi, Ht


"""
    unsupervised
"""
def train_final_hash_one_epoch(train_loader, model, Criterion, optimizer, scheduler, scaler, epoch, H, Hi, Ht, train_labels, args):
    """Run one unsupervised hashing epoch (NT-Xent contrastive objective).

    Same code-buffer bookkeeping as ``train_hash_one_epoch`` but replaces the
    semantic cross-entropy term with a contrastive NT-Xent loss.

    Returns the updated (H, Hi, Ht).
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    prec_meter = AverageMeter('Precision', '2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter],
        # BUGFIX: display the epoch 1-based, consistent with the other loops
        prefix="Training: Epoch=[{}/{}] ".format(epoch + 1, args.epochs))

    precisions = []

    model.train()
    time.sleep(2)
    train_labels = train_labels.cuda(non_blocking=True)
    num_train = len(train_loader.dataset)
    end = time.time()
    for i, (idx, image, text, label, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        image = torch.tensor(image).cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        with amp.autocast():
            img_hash, txt_hash, out_hash, score = model(image, text)
            S = CalcSim(target, train_labels)

        H[idx, :] = out_hash.float()
        Hi[idx, :] = img_hash.float()
        Ht[idx, :] = txt_hash.float()
        Bbatch = torch.sign(out_hash)

        num_batch = len(target)
        # contrastive term + pairwise hashing term
        loss_ntxent = CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args)
        _, _, loss_p = Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args)
        loss = (1 - args.eta) * loss_p + args.eta * loss_ntxent

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        scaler.step(optimizer)
        scaler.update()

        # metric
        prec = trainMetric(score, label)
        precisions.append(prec.detach().cpu())

        loss_meter.update(loss.item(), image.size(0))
        prec_meter.update(prec.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)

    logger.info('Classification precision:{:.2f}%'.format(np.sum(precisions) / len(train_loader.dataset) * 100.))

    return H, Hi, Ht


@torch.no_grad()
def validate_one_epoch(val_loader, model, epoch, args):
    """Evaluate classification precision and mean loss over the validation set.

    Returns the precision as a percentage.
    """
    prec = []
    loss = []
    model.eval()
    time.sleep(2)
    for i, (idx, image, text, label, target) in enumerate(val_loader):
        image = torch.tensor(image).cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # inference
        pred = model(image, text)
        loss.append(F.cross_entropy(pred, target).detach().cpu())

        pred = torch.argmax(pred, dim=1)
        res = torch.sum(pred == label)
        prec.append(res.detach().cpu())

    res = 100. * np.sum(prec) / len(val_loader.dataset)
    l = np.mean(loss)
    head = 'Evaluation: Epoch=[{}/{}] Prec={:.2f}% loss={}'.format(
        epoch + 1, args.epochs, res, l)

    logger.info(head)

    return res
@torch.no_grad()
def validate_hash_one_epoch(query_loader, db_loader, model, epoch, query_labels, db_labels, args):
    """Generate hash codes for query/database sets and report retrieval mAP."""
    model.eval()
    time.sleep(2)
    qB, _, _ = GenerateCode(model, query_loader, args)
    dB, _, _ = GenerateCode(model, db_loader, args)
    # renamed from ``map`` to avoid shadowing the builtin
    retrieval_map = CalcTopMap(qB, dB, query_labels, db_labels)

    head = 'Evaluation: Epoch=[{}/{}] Map={:.2f}'.format(
        epoch + 1, args.epochs, retrieval_map)

    logger.info(head)

    return retrieval_map


@torch.no_grad()
def validate_final_hash_one_epoch(query_loader, db_loader, model, epoch, query_labels, db_labels, args):
    """Evaluate fused and per-modality hash codes at top-20 mAP.

    Logs fused, i->t, t->i, t->t and i->i retrieval mAP; returns the fused mAP.
    """
    model.eval()
    time.sleep(2)
    qB, qBi, qBt = GenerateCode(model, query_loader, args)
    dB, dBi, dBt = GenerateCode(model, db_loader, args)

    fused_map = CalcTopMap(qB, dB, query_labels, db_labels, 20)
    logger.info('Evaluation: Epoch=[{}/{}] Map={:.2f}'.format(
        epoch + 1, args.epochs, fused_map))

    map1 = CalcTopMap(qBi, dBt, query_labels, db_labels, 20)
    logger.info('Evaluation: Epoch=[{}/{}] i -> t Map={:.2f}'.format(
        epoch + 1, args.epochs, map1))

    map2 = CalcTopMap(qBt, dBi, query_labels, db_labels, 20)
    logger.info('Evaluation: Epoch=[{}/{}] t -> i Map={:.2f}'.format(
        epoch + 1, args.epochs, map2))

    map3 = CalcTopMap(qBt, dBt, query_labels, db_labels, 20)
    logger.info('Evaluation: Epoch=[{}/{}] t -> t Map={:.2f}'.format(
        epoch + 1, args.epochs, map3))

    map4 = CalcTopMap(qBi, dBi, query_labels, db_labels, 20)
    logger.info('Evaluation: Epoch=[{}/{}] i -> i Map={:.2f}'.format(
        epoch + 1, args.epochs, map4))

    return fused_map


@torch.no_grad()
def inference(test_loader, model, args):
    """Per-sample classification inference.

    NOTE(review): the ``pred == label`` comparison only behaves as a scalar
    test with batch size 1 — confirm the loader's batch size.

    Returns the list of [idx, pred, label] triples that were misclassified.
    """
    prec = []
    wrong = []
    cnt = 0
    tbar = tqdm(test_loader, desc="Inference:", ncols=100)
    model.eval()
    time.sleep(2)
    for idx, image, text, label, target in tbar:
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # inference
        pred = model(image, text)
        pred = torch.argmax(pred, dim=1)
        if pred == label:
            cnt += 1
        else:
            wrong.append([idx, pred, label])

        prec.append(pred)

    logger.info('=> Metric Calculation <=')
    # BUGFIX: log message typo "Prcision" -> "Precision"
    logger.info("Precision = {}".format(cnt / len(test_loader.dataset)))
    return wrong


@torch.no_grad()
def inference_hash(test_loader, model, args):
    """Per-sample inference for the hashing model (same flow as ``inference``).

    Returns the list of [idx, pred, label] triples that were misclassified.
    """
    prec = []
    wrong = []
    cnt = 0
    tbar = tqdm(test_loader, desc="Inference:", ncols=100)
    model.eval()
    time.sleep(2)
    for idx, image, text, label, target in tbar:
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # inference
        pred = model(image, text)
        pred = torch.argmax(pred, dim=1)
        if pred == label:
            cnt += 1
        else:
            wrong.append([idx, pred, label])

        prec.append(pred)

    logger.info('=> Metric Calculation <=')
    # BUGFIX: log message typo "Prcision" -> "Precision"
    logger.info("Precision = {}".format(cnt / len(test_loader.dataset)))
    return wrong


def train_clip_one_epoch(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one CLIP-style contrastive epoch.

    Symmetric cross-entropy over the in-batch image/text logit matrices:
    the i-th text is the positive match for the i-th image.
    """
    model.train()
    time.sleep(2)
    criterion = nn.CrossEntropyLoss()
    for i, (idx, image, text, label, target) in enumerate(train_loader):
        image = torch.tensor(image).cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        with amp.autocast():
            logits_per_image, logits_per_text = model(image, text)
            ground_truth = torch.arange(len(image)).cuda()
            loss = 1 / 2 * (criterion(logits_per_image, ground_truth)
                            + criterion(logits_per_text, ground_truth))

        # BUGFIX: removed leftover debug ``print`` calls on the logits/loss.
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()
def train_1CISEN_one_epoch(train_loader, model, optimizer, scaler, args, epoch):
    """First-stage CISEN training epoch.

    Returns the sum over batches of the mean batch loss.
    """
    model.train()
    mean_loss = 0

    for i, (image, text, _) in enumerate(tqdm(train_loader)):
        image = image.cuda()
        text = text.cuda()

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(image, text, '1st')
            tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # backward through the scalar mean loss
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()
    return mean_loss


def train_2CISEN_one_epoch(train_loader, model, optimizer, scaler, args, stage):
    """CISEN training epoch for an arbitrary ``stage``.

    Returns the sum over batches of the mean batch loss.
    """
    model.train()
    mean_loss = 0

    for i, (image, text, _) in enumerate(tqdm(train_loader)):
        image = image.cuda()
        text = text.cuda()

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(image, text, stage)
            tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # backward through the scalar mean loss
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()
    return mean_loss


def train_Filip_one_epoch(train_loader, model, optimizer, scaler, args, stage):
    """FILIP baseline training epoch; returns the summed mean batch loss."""
    model.train()
    mean_loss = 0

    for i, (img, text, _) in enumerate(tqdm(train_loader)):
        img = img.cuda(non_blocking=True)
        text = text.squeeze(1).cuda(non_blocking=True)

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(img, text)
        tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # BUGFIX: backprop through the scalar mean ``tloss`` instead of the
        # raw ``loss`` — ``backward()`` on a non-scalar tensor raises, and
        # this matches train_1CISEN/train_2CISEN.
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()

    return mean_loss


def train_Blip_one_epoch(train_loader, model, optimizer, scaler, args, stage):
    """BLIP baseline training epoch; returns the summed mean batch loss."""
    model.train()
    mean_loss = 0

    for i, (inputs, _) in enumerate(tqdm(train_loader)):
        # move every tokenizer field to the GPU (squeeze the fake batch dim)
        for key in inputs:
            inputs[key] = inputs[key].squeeze(1).cuda(non_blocking=True)

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(inputs)
        tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # BUGFIX: backward through scalar ``tloss``, not raw ``loss``
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()

    return mean_loss


def train_Albef_one_epoch(train_loader, model, optimizer, scaler, args, stage):
    """ALBEF baseline training epoch; returns the summed mean batch loss."""
    model.train()
    mean_loss = 0

    for i, (inputs, _) in enumerate(tqdm(train_loader)):
        inputs['img'] = inputs['img'].cuda()

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(inputs)
        tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # BUGFIX: backward through scalar ``tloss``, not raw ``loss``
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()

    return mean_loss


def train_Align_one_epoch(train_loader, model, optimizer, scaler, args, stage):
    """ALIGN baseline training epoch; returns the summed mean batch loss."""
    model.train()
    mean_loss = 0

    for i, (inputs, _) in enumerate(tqdm(train_loader)):
        for key in inputs:
            inputs[key] = inputs[key].squeeze(1).cuda()

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(inputs)
        tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # BUGFIX: backward through scalar ``tloss``, not raw ``loss``
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()

    return mean_loss
def truncate_text(text, max_length):
    """Return ``text`` truncated to at most ``max_length`` characters."""
    if len(text) > max_length:
        text = text[:max_length]
    return text


def train_CISEN_lclip(train_loader, model, optimizer, scaler, args, stage):
    """CISEN (long-CLIP variant) training epoch.

    Returns the sum over batches of the mean batch loss.
    """
    model.train()
    mean_loss = 0

    for i, (image, text, _) in enumerate(tqdm(train_loader)):
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(image, text, stage)
            tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # BUGFIX: backprop through the scalar mean ``tloss`` instead of the
        # possibly non-scalar ``loss`` (consistent with train_1/2CISEN).
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()
    return mean_loss


def train_3CISEN_one_epoch(train_loader, model, optimizer, scaler, args, stage):
    """Third-stage CISEN training epoch.

    Returns the sum over batches of the mean batch loss.
    """
    model.train()
    mean_loss = 0

    for i, (image, text, _) in enumerate(tqdm(train_loader)):
        image = image.cuda()
        text = text.cuda()

        optimizer.zero_grad()
        # forward
        with amp.autocast():
            loss, _, _, _ = model(image, text, stage)
            tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # BUGFIX: backward through scalar ``tloss``, not raw ``loss``
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()
    return mean_loss


# old version
def train_CISEN(train_loader, model, optimizer, scaler, args, epoch):
    """Legacy CISEN training epoch (kept for reference only).

    NOTE(review): ``text_db`` is never populated — its collection loop is
    commented out — so ``text_db[i]`` raises IndexError on the first batch.
    This path looks dead; confirm before reviving it.
    """
    model.train()
    mean_loss = 0
    text_db = []
    # collect text
    # for i, (image, text, target) in enumerate(tqdm(train_loader)):
    #     text = text.cuda()
    #     text_db.append(text)

    for i, (image, text, target) in enumerate(tqdm(train_loader)):
        image = torch.tensor(image).cuda()
        text = text.cuda()
        db = text_db[i]
        target = target.cuda()

        # forward
        with amp.autocast():
            optimizer.zero_grad()
            loss, fv, ft, fl = model(image, text, db)
            tloss = torch.mean(loss)
        mean_loss += tloss.item()

        # backward
        scaler.scale(tloss).backward()
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        scaler.step(optimizer)
        scaler.update()
    return mean_loss


# (removed: large commented-out legacy ``eval_CISEN`` — superseded by
# ``test_CISEN_one_epoch`` below)
def test_CISEN_one_epoch(test_loader, model, len_testdataset, args):
    """Evaluate CISEN retrieval quality at top-5 (ACG / NDCG / mAP / WMAP).

    Computes raw image<->text metrics when the model returns image features,
    and the augmented variants when it returns enhanced features. Returns a
    dict mapping metric name to its mean value over the test set.
    """
    model.eval()
    topk = 5

    result = dict()    # metric name -> list of per-batch tensors
    results = dict()   # metric name -> averaged scalar

    for i, (image, text, targets) in enumerate(tqdm(test_loader)):
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        with torch.no_grad():
            loss, enhanced_features, image_features, text_features = model(image, text, '1st')

            if image_features is not None:
                # raw cross-modal similarity matrices, row-normalized
                it_logits = image_features @ text_features.t()
                ti_logits = text_features @ image_features.t()
                it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True)
                ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True)
                it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk)
                ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk)
                # record
                result.setdefault("i->t:acg", []).append(it_acg)
                result.setdefault("i->t:ndcg", []).append(it_ndcg)
                result.setdefault("i->t:map", []).append(it_map)
                result.setdefault("i->t:wmap", []).append(it_wmap)

                result.setdefault("t->i:acg", []).append(ti_acg)
                result.setdefault("t->i:ndcg", []).append(ti_ndcg)
                result.setdefault("t->i:map", []).append(ti_map)
                result.setdefault("t->i:wmap", []).append(ti_wmap)

            if enhanced_features is not None:
                # augmented (enhanced-feature) similarity matrices
                it_aug_logits = enhanced_features @ text_features.t()
                ti_aug_logits = text_features @ enhanced_features.t()
                it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True)
                ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True)
                it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk)
                ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk)
                # record
                result.setdefault("i->t(aug):acg", []).append(it_aug_acg)
                result.setdefault("i->t(aug):ndcg", []).append(it_aug_ndcg)
                result.setdefault("i->t(aug):map", []).append(it_aug_map)
                result.setdefault("i->t(aug):wmap", []).append(it_aug_wmap)

                result.setdefault("t->i(aug):acg", []).append(ti_aug_acg)
                result.setdefault("t->i(aug):ndcg", []).append(ti_aug_ndcg)
                result.setdefault("t->i(aug):map", []).append(ti_aug_map)
                result.setdefault("t->i(aug):wmap", []).append(ti_aug_wmap)

    # average each metric over all batches
    for key, value in result.items():
        acc = torch.cat(value)
        acc = acc.float().mean().item()
        results.setdefault(key, acc)
        logger.info(f"{key}: {acc}")

    return results


def eval_Filip_test(test_loader, model, stage, model_name):
    """Evaluate the FILIP baseline: retrieval metrics at k in {5,10,20,50,100}.

    With ``stage == '2nd'`` the raw image features are ignored and only the
    enhanced-feature metrics are computed. Returns a dict of averaged metrics.
    """
    model.eval()
    k = [5, 10, 20, 50, 100]

    result = dict()
    results = dict()

    for i, (img, text, targets) in enumerate(tqdm(test_loader)):
        text = text.squeeze(1)

        with torch.no_grad():
            with autocast():
                loss, enhanced_features, image_features, text_features = model(img, text)

            if stage == "2nd":
                image_features = None

            if image_features is not None:
                # FILIP optionally scores with token-wise late interaction
                it_logits = (
                    token_wise_similarity(image_features, text_features)
                    if model.is_token_wise
                    else (image_features @ text_features.T)
                )
                ti_logits = (
                    token_wise_similarity(text_features, image_features)
                    if model.is_token_wise
                    else (text_features @ image_features.T)
                )
                it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True)
                ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True)
                for topk in k:
                    if topk > it_logits.shape[0]:
                        continue

                    IT_acg = "i->t:acg@" + str(topk)
                    IT_ndcg = "i->t:ndcg@" + str(topk)
                    IT_map = "i->t:map@" + str(topk)
                    IT_wmap = "i->t:wmap@" + str(topk)

                    TI_acg = "t->i:acg@" + str(topk)
                    TI_ndcg = "t->i:ndcg@" + str(topk)
                    TI_map = "t->i:map@" + str(topk)
                    TI_wmap = "t->i:wmap@" + str(topk)

                    it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk)
                    ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk)
                    # record
                    result.setdefault(IT_acg, []).append(it_acg)
                    result.setdefault(IT_ndcg, []).append(it_ndcg)
                    result.setdefault(IT_map, []).append(it_map)
                    result.setdefault(IT_wmap, []).append(it_wmap)

                    result.setdefault(TI_acg, []).append(ti_acg)
                    result.setdefault(TI_ndcg, []).append(ti_ndcg)
                    result.setdefault(TI_map, []).append(ti_map)
                    result.setdefault(TI_wmap, []).append(ti_wmap)

            if enhanced_features is not None:
                it_aug_logits = enhanced_features @ text_features.t()
                ti_aug_logits = text_features @ enhanced_features.t()
                it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True)
                ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True)
                for topk in k:
                    if topk > it_aug_logits.shape[0]:
                        continue
                    it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk)
                    ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk)
                    # record
                    IT_aug_acg = "i->t(aug):acg@" + str(topk)
                    IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk)
                    IT_aug_map = "i->t(aug):map@" + str(topk)
                    IT_aug_wmap = "i->t(aug):wmap@" + str(topk)

                    TI_aug_acg = "t->i(aug):acg@" + str(topk)
                    # BUGFIX: result key was "t->(aug):ndcg@..." (missing "i")
                    TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk)
                    TI_aug_map = "t->i(aug):map@" + str(topk)
                    TI_aug_wmap = "t->i(aug):wmap@" + str(topk)

                    result.setdefault(IT_aug_acg, []).append(it_aug_acg)
                    result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg)
                    result.setdefault(IT_aug_map, []).append(it_aug_map)
                    result.setdefault(IT_aug_wmap, []).append(it_aug_wmap)

                    result.setdefault(TI_aug_acg, []).append(ti_aug_acg)
                    result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg)
                    result.setdefault(TI_aug_map, []).append(ti_aug_map)
                    result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap)

    # average each metric over all batches
    for key, value in result.items():
        value = torch.cat(value)
        acc = value.float().mean().item()
        results.setdefault(key, acc)

    return results
def eval_Blip_test(test_loader, model, stage, model_name):
    """Evaluate the BLIP baseline: retrieval metrics at k in {5,10,20,50,100}.

    With ``stage == '2nd'`` raw image features are ignored. Returns a dict of
    averaged metrics.
    """
    model.eval()
    k = [5, 10, 20, 50, 100]

    result = dict()
    results = dict()

    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        for key in inputs:
            inputs[key] = inputs[key].squeeze(1).cuda()
        # moved outside the loop above: deleting a key while iterating the
        # dict would raise RuntimeError, and targets only needs one transfer
        targets = targets.cuda()
        del inputs['token_type_ids']

        with torch.no_grad():
            with autocast():
                loss, enhanced_features, image_features, text_features = model(inputs)

            if stage == "2nd":
                image_features = None

            if image_features is not None:
                it_logits = image_features @ text_features.t()
                ti_logits = text_features @ image_features.t()
                it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True)
                ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True)
                for topk in k:
                    if topk > it_logits.shape[0]:
                        continue

                    IT_acg = "i->t:acg@" + str(topk)
                    IT_ndcg = "i->t:ndcg@" + str(topk)
                    IT_map = "i->t:map@" + str(topk)
                    IT_wmap = "i->t:wmap@" + str(topk)

                    TI_acg = "t->i:acg@" + str(topk)
                    TI_ndcg = "t->i:ndcg@" + str(topk)
                    TI_map = "t->i:map@" + str(topk)
                    TI_wmap = "t->i:wmap@" + str(topk)

                    it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk)
                    ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk)
                    # record
                    result.setdefault(IT_acg, []).append(it_acg)
                    result.setdefault(IT_ndcg, []).append(it_ndcg)
                    result.setdefault(IT_map, []).append(it_map)
                    result.setdefault(IT_wmap, []).append(it_wmap)

                    result.setdefault(TI_acg, []).append(ti_acg)
                    result.setdefault(TI_ndcg, []).append(ti_ndcg)
                    result.setdefault(TI_map, []).append(ti_map)
                    result.setdefault(TI_wmap, []).append(ti_wmap)

            if enhanced_features is not None:
                it_aug_logits = enhanced_features @ text_features.t()
                ti_aug_logits = text_features @ enhanced_features.t()
                it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True)
                ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True)
                for topk in k:
                    if topk > it_aug_logits.shape[0]:
                        continue
                    it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk)
                    ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk)
                    # record
                    IT_aug_acg = "i->t(aug):acg@" + str(topk)
                    IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk)
                    IT_aug_map = "i->t(aug):map@" + str(topk)
                    IT_aug_wmap = "i->t(aug):wmap@" + str(topk)

                    TI_aug_acg = "t->i(aug):acg@" + str(topk)
                    # BUGFIX: result key was "t->(aug):ndcg@..." (missing "i")
                    TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk)
                    TI_aug_map = "t->i(aug):map@" + str(topk)
                    TI_aug_wmap = "t->i(aug):wmap@" + str(topk)

                    result.setdefault(IT_aug_acg, []).append(it_aug_acg)
                    result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg)
                    result.setdefault(IT_aug_map, []).append(it_aug_map)
                    result.setdefault(IT_aug_wmap, []).append(it_aug_wmap)

                    result.setdefault(TI_aug_acg, []).append(ti_aug_acg)
                    result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg)
                    result.setdefault(TI_aug_map, []).append(ti_aug_map)
                    result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap)

    # average each metric over all batches
    for key, value in result.items():
        value = torch.cat(value)
        acc = value.float().mean().item()
        results.setdefault(key, acc)

    return results


def eval_Albef_test(test_loader, model, stage, model_name):
    # NOTE(review): this function continues past the end of the visible
    # chunk; only the portion visible here is reformatted, unchanged in
    # behavior — the remainder follows below.
    model.eval()
    Top1_acc = []
    Topk_acc = []
    k = [5, 10, 20, 50, 100]

    result = dict()
    results = dict()

    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        inputs['img'] = inputs['img']

        with torch.no_grad():
            loss, enhanced_features, image_features, text_features = model(inputs)

            if stage == "2nd":
                image_features = None
            if image_features is not None:
                it_logits = image_features @ text_features.t()
ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_aug_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + 
def _append_acg_metrics(result, prefix, logits, targets, topk):
    """Compute ACG/NDCG/mAP/wmAP at ``topk`` and append them to ``result``
    under ``"<prefix>:<metric>@<topk>"`` keys."""
    acg, ndcg, m_ap, w_map = compute_ACG(logits, targets, topk)
    result.setdefault(f"{prefix}:acg@{topk}", []).append(acg)
    result.setdefault(f"{prefix}:ndcg@{topk}", []).append(ndcg)
    result.setdefault(f"{prefix}:map@{topk}", []).append(m_ap)
    result.setdefault(f"{prefix}:wmap@{topk}", []).append(w_map)


def eval_Align_test(test_loader, model, stage, model_name):
    """Evaluate an ALIGN-style model on cross-modal retrieval.

    Computes ACG / NDCG / mAP / weighted-mAP at cutoffs 5/10/20/50/100 for
    image->text and text->image retrieval, using raw image features and,
    when present, the model's enhanced features. Unlike the BLIP/CLIP
    variants this runs on whatever device the inputs arrive on (no explicit
    ``.cuda()`` transfer) and without autocast.

    Args:
        test_loader: yields ``(inputs, targets)``; each ``inputs`` tensor
            has a leading singleton dim that is squeezed away.
        model: returns ``(loss, enhanced_features, image_features,
            text_features)``.
        stage: when ``"2nd"``, raw image features are skipped.
        model_name: unused; kept for signature parity with siblings.

    Returns:
        dict mapping metric key (e.g. ``"t->i:wmap@20"``) to its mean value.

    Bug fixed: the t->i enhanced-NDCG key previously read
    ``"t->(aug):ndcg@..."`` (missing the ``i``).
    """
    model.eval()

    k = [5, 10, 20, 50, 100]

    result = dict()
    results = dict()

    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        for key in inputs:
            inputs[key] = inputs[key].squeeze(1)

        with torch.no_grad():
            loss, enhanced_features, image_features, text_features = model(inputs)

        if stage == "2nd":
            # second stage scores enhanced features only
            image_features = None

        if image_features is not None:
            it_logits = image_features @ text_features.t()
            ti_logits = text_features @ image_features.t()
            it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True)
            ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True)
            for topk in k:
                # skip cutoffs larger than the batch
                if topk > it_logits.shape[0]:
                    continue
                _append_acg_metrics(result, "i->t", it_logits, targets, topk)
                _append_acg_metrics(result, "t->i", ti_logits, targets, topk)

        if enhanced_features is not None:
            it_aug_logits = enhanced_features @ text_features.t()
            ti_aug_logits = text_features @ enhanced_features.t()
            it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True)
            ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True)
            for topk in k:
                if topk > it_aug_logits.shape[0]:
                    continue
                _append_acg_metrics(result, "i->t(aug)", it_aug_logits, targets, topk)
                # BUGFIX: key was "t->(aug):ndcg@..." for the ndcg entry
                _append_acg_metrics(result, "t->i(aug)", ti_aug_logits, targets, topk)

    for key, value in result.items():
        value = torch.cat(value)
        results.setdefault(key, value.float().mean().item())

    return results
def eval_Blip_foo(test_loader, model, stage, model_name):
    """Evaluate a single named retrieval metric for a BLIP-style model.

    ``model_name`` encodes the query, e.g. ``"i->t(aug):map@10"``:
    a task (``i->t`` / ``t->i``, optionally suffixed ``(aug)``), a metric
    (acg / ndcg / map / wmap) and a cutoff after ``@``. The function
    returns ``{model_name: mean_value}``.

    NOTE(review): when the model produces enhanced features, both the raw
    and the enhanced scores are appended under the same ``model_name`` key,
    so the reported mean mixes the two — confirm this is intended.
    """
    model.eval()

    # kept from the original interface; not used below
    Top1_acc = []
    Topk_acc = []

    topk = int(model_name.split('@')[-1])
    head = model_name.split('@')[0]
    mode = head.split(':')[-1]
    task = head.split(':')[0]
    if 'aug' in task:
        task = task[:-5]  # drop the trailing "(aug)" suffix

    result = dict()
    results = dict()
    metric_names = ('acg', 'ndcg', 'map', 'wmap')

    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        for key in inputs:
            inputs[key] = inputs[key].squeeze(1).cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # BLIP's text tower does not consume token_type_ids
        del inputs['token_type_ids']

        with torch.no_grad():
            with autocast():
                loss, enhanced_features, image_features, text_features = model(inputs)

        if image_features is not None:
            sim_it = image_features @ text_features.t()
            sim_ti = text_features @ image_features.t()
            sim_it = sim_it / sim_it.norm(dim=-1, keepdim=True)
            sim_ti = sim_ti / sim_ti.norm(dim=-1, keepdim=True)

            it_vals = dict(zip(metric_names, compute_ACG(sim_it, targets, topk)))
            ti_vals = dict(zip(metric_names, compute_ACG(sim_ti, targets, topk)))

            if task == 'i->t':
                if mode in it_vals:
                    result.setdefault(model_name, []).append(it_vals[mode])
                else:
                    print('mode should be in acg ndcg map wmap')
            elif task == 't->i':
                if mode in ti_vals:
                    result.setdefault(model_name, []).append(ti_vals[mode])

        if enhanced_features is not None:
            aug_it = enhanced_features @ text_features.t()
            aug_ti = text_features @ enhanced_features.t()
            aug_it = aug_it / aug_it.norm(dim=-1, keepdim=True)
            aug_ti = aug_ti / aug_ti.norm(dim=-1, keepdim=True)

            aug_it_vals = dict(zip(metric_names, compute_ACG(aug_it, targets, topk)))
            aug_ti_vals = dict(zip(metric_names, compute_ACG(aug_ti, targets, topk)))

            if task == 'i->t':
                if mode in aug_it_vals:
                    result.setdefault(model_name, []).append(aug_it_vals[mode])
            elif task == 't->i':
                if mode in aug_ti_vals:
                    result.setdefault(model_name, []).append(aug_ti_vals[mode])

    # aggregation of per-batch values into a single mean per key
    for key, value in result.items():
        merged = torch.cat(value)
        results.setdefault(key, merged.float().mean().item())

    return results
enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + if task == 'i->t': + if 'acg' == mode: + result.setdefault(model_name, []).append(it_acg) + elif 'ndcg' == mode: + result.setdefault(model_name, []).append(it_ndcg) + elif 'map' == mode: + result.setdefault(model_name, []).append(it_map) + elif 'wmap' == mode: + result.setdefault(model_name, []).append(it_wmap) + else: + print('mode should be in acg ndcg map wmap') + elif task == 't->i': + if 'acg' == mode: + result.setdefault(model_name, []).append(ti_acg) + elif 'ndcg' == mode: + result.setdefault(model_name, []).append(ti_ndcg) + elif 'map' == mode: + result.setdefault(model_name, []).append(ti_map) + elif 'wmap' == mode: + result.setdefault(model_name, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + if task == 'i->t': + # record + if 'acg' == mode: + result.setdefault(model_name, []).append(it_aug_acg) + elif 'ndcg' == mode: + result.setdefault(model_name, []).append(it_aug_ndcg) + elif 'map' == mode: + result.setdefault(model_name, []).append(it_aug_map) + elif 'wmap' == mode: + result.setdefault(model_name , 
def eval_Align_foo(test_loader, model, stage, model_name):
    """Evaluate a single named retrieval metric for an ALIGN-style model.

    ``model_name`` encodes the query, e.g. ``"t->i:ndcg@20"``: a task
    (``i->t`` / ``t->i``, optionally suffixed ``(aug)``), a metric
    (acg / ndcg / map / wmap) and a cutoff after ``@``. Runs without any
    explicit device transfer or autocast, unlike the BLIP/CLIP siblings.

    NOTE(review): when the model produces enhanced features, both the raw
    and the enhanced scores are appended under the same ``model_name`` key,
    so the reported mean mixes the two — confirm this is intended.
    """
    model.eval()

    # kept from the original interface; not used below
    Top1_acc = []
    Topk_acc = []

    topk = int(model_name.split('@')[-1])
    head = model_name.split('@')[0]
    mode = head.split(':')[-1]
    task = head.split(':')[0]
    if 'aug' in task:
        task = task[:-5]  # drop the trailing "(aug)" suffix

    result = dict()
    results = dict()
    metric_names = ('acg', 'ndcg', 'map', 'wmap')

    for i, (inputs, targets) in enumerate(tqdm(test_loader)):
        for key in inputs:
            inputs[key] = inputs[key].squeeze(1)

        with torch.no_grad():
            loss, enhanced_features, image_features, text_features = model(inputs)

        if image_features is not None:
            sim_it = image_features @ text_features.t()
            sim_ti = text_features @ image_features.t()
            sim_it = sim_it / sim_it.norm(dim=-1, keepdim=True)
            sim_ti = sim_ti / sim_ti.norm(dim=-1, keepdim=True)

            it_vals = dict(zip(metric_names, compute_ACG(sim_it, targets, topk)))
            ti_vals = dict(zip(metric_names, compute_ACG(sim_ti, targets, topk)))

            if task == 'i->t':
                if mode in it_vals:
                    result.setdefault(model_name, []).append(it_vals[mode])
                else:
                    print('mode should be in acg ndcg map wmap')
            elif task == 't->i':
                if mode in ti_vals:
                    result.setdefault(model_name, []).append(ti_vals[mode])

        if enhanced_features is not None:
            aug_it = enhanced_features @ text_features.t()
            aug_ti = text_features @ enhanced_features.t()
            aug_it = aug_it / aug_it.norm(dim=-1, keepdim=True)
            aug_ti = aug_ti / aug_ti.norm(dim=-1, keepdim=True)

            aug_it_vals = dict(zip(metric_names, compute_ACG(aug_it, targets, topk)))
            aug_ti_vals = dict(zip(metric_names, compute_ACG(aug_ti, targets, topk)))

            if task == 'i->t':
                if mode in aug_it_vals:
                    result.setdefault(model_name, []).append(aug_it_vals[mode])
            elif task == 't->i':
                if mode in aug_ti_vals:
                    result.setdefault(model_name, []).append(aug_ti_vals[mode])

    # aggregation of per-batch values into a single mean per key
    for key, value in result.items():
        merged = torch.cat(value)
        results.setdefault(key, merged.float().mean().item())

    return results
model.eval() + # cnt = 0 + + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + topk = int(model_name.split('@')[-1]) + mode = model_name.split('@')[0].split(':')[-1] + task = model_name.split('@')[0].split(':')[0] + if 'aug' in task: + task = task[:-5] + + result = dict() + results = dict() + + for i, (image, text, targets) in enumerate(tqdm(test_loader)): + image = image.cuda(non_blocking=True) + text = text.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + loss, enhanced_features, image_features, text_features = model(image, text, stage) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + if task == 'i->t': + if 'acg' == mode: + result.setdefault(model_name, []).append(it_acg) + elif 'ndcg' == mode: + result.setdefault(model_name, []).append(it_ndcg) + elif 'map' == mode: + result.setdefault(model_name, []).append(it_map) + elif 'wmap' == mode: + result.setdefault(model_name, []).append(it_wmap) + else: + print('mode should be in acg ndcg map wmap') + elif task == 't->i': + if 'acg' == mode: + result.setdefault(model_name, []).append(ti_acg) + elif 'ndcg' == mode: + result.setdefault(model_name, []).append(ti_ndcg) + elif 'map' == mode: + result.setdefault(model_name, []).append(ti_map) + elif 'wmap' == mode: + result.setdefault(model_name, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits 
def _append_acg_metrics(result, prefix, logits, targets, topk):
    """Compute ACG/NDCG/mAP/wmAP at ``topk`` and append them to ``result``
    under ``"<prefix>:<metric>@<topk>"`` keys."""
    acg, ndcg, m_ap, w_map = compute_ACG(logits, targets, topk)
    result.setdefault(f"{prefix}:acg@{topk}", []).append(acg)
    result.setdefault(f"{prefix}:ndcg@{topk}", []).append(ndcg)
    result.setdefault(f"{prefix}:map@{topk}", []).append(m_ap)
    result.setdefault(f"{prefix}:wmap@{topk}", []).append(w_map)


def eval_CISEN_(test_loader, model, stage, args):
    """Distributed evaluation of a CISEN model on cross-modal retrieval.

    Computes ACG / NDCG / mAP / weighted-mAP at cutoffs 5/10/20/50/100 for
    image->text and text->image retrieval (raw and, when present, enhanced
    features), then all-gathers the per-batch values across ranks.

    Args:
        test_loader: yields ``(image, text, targets)`` batches.
        model: called as ``model(image, text, stage)``; returns
            ``(loss, enhanced_features, image_features, text_features)``.
        stage: forwarded to the model.
        args: unused; kept for signature parity with siblings.

    Returns:
        On rank 0: dict mapping metric key to the mean over all gathered
        values. On other ranks: an empty dict.

    Bugs fixed:
      * the aug top-k guard compared against ``it_logits.shape[0]``
        (raises NameError when ``image_features`` is None, and is the
        wrong tensor regardless) — now ``it_aug_logits.shape[0]``, as in
        the sibling eval functions;
      * the t->i enhanced-NDCG key previously read ``"t->(aug):ndcg@..."``
        (missing the ``i``).
    """
    model.eval()

    k = [5, 10, 20, 50, 100]

    result = dict()
    results = dict()

    for i, (image, text, targets) in enumerate(tqdm(test_loader)):
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        with torch.no_grad():
            loss, enhanced_features, image_features, text_features = model(image, text, stage)

        if image_features is not None:
            it_logits = image_features @ text_features.t()
            ti_logits = text_features @ image_features.t()
            it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True)
            ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True)
            for topk in k:
                # skip cutoffs larger than the batch
                if topk > it_logits.shape[0]:
                    continue
                _append_acg_metrics(result, "i->t", it_logits, targets, topk)
                _append_acg_metrics(result, "t->i", ti_logits, targets, topk)

        if enhanced_features is not None:
            it_aug_logits = enhanced_features @ text_features.t()
            ti_aug_logits = text_features @ enhanced_features.t()
            it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True)
            ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True)
            for topk in k:
                # BUGFIX: guard previously checked it_logits.shape[0]
                if topk > it_aug_logits.shape[0]:
                    continue
                _append_acg_metrics(result, "i->t(aug)", it_aug_logits, targets, topk)
                # BUGFIX: key was "t->(aug):ndcg@..." for the ndcg entry
                _append_acg_metrics(result, "t->i(aug)", ti_aug_logits, targets, topk)

    # gather per-batch metric values from every rank
    gathered_results = dict()
    for key, value in result.items():
        local = torch.cat(value)  # hoisted: was concatenated twice before
        tensor_list = [torch.zeros_like(local).cuda() for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list, local)
        gathered_results[key] = torch.cat(tensor_list)

    # only the main rank reduces to scalar means
    if dist.get_rank() == 0:
        for key, value in gathered_results.items():
            results.setdefault(key, value.float().mean().item())

    return results
Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5, 10, 20, 50, 100] + + result = dict() + results = dict() + + for i, (image, text, targets) in enumerate(tqdm(test_loader)): + image = image.cuda(non_blocking=True) + text = text.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + loss, enhanced_features, image_features, text_features = model(image, text, stage) + if stage == "2nd": + image_features = None + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / 
it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_aug_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + result.setdefault(IT_aug_wmap, []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + + + for key, value in result.items(): + value = torch.cat(value) + acc = value.float().mean().item() + results.setdefault(key, acc) + + + return results + + + + # F1_3,P_3,R_3 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # F1_5,P_5,R_5 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # logger.info('g_k=3: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_3), torch.mean(P_3), torch.mean(R_3))) + # logger.info('g_k=5: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_5), torch.mean(P_5), torch.mean(R_5))) + +def test1_Filip_one_epoch(test_loader, model, stage, args): + model.eval() + # cnt = 0 + + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5, 10, 20, 50, 100] + + result = dict() + 
results = dict() + + for i, (img, text, targets) in enumerate(tqdm(test_loader)): + img = img.cuda(non_blocking=True) + text = text.squeeze(1).cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + with autocast(): + loss, enhanced_features, image_features, text_features = model(img, text) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = ( + token_wise_similarity(image_features, text_features) + if model.module.is_token_wise + else (image_features @ text_features.T) + ) + ti_logits = ( + token_wise_similarity(text_features, image_features) + if model.module.is_token_wise + else (text_features @ image_features.T) + ) + + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ 
enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + result.setdefault(IT_aug_wmap, []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + 
gathered_results = dict() + for key, value in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): + acc = value.float().mean().item() + results.setdefault(key, acc) + logger.info(f"{key}: {acc}") + + return results + +def test1_Blip_one_epoch(test_loader, model, stage, args): + model.eval() + # cnt = 0 + + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5, 10, 20, 50, 100] + + result = dict() + results = dict() + + for i, (inputs, targets) in enumerate(tqdm(test_loader)): + + for key in inputs: + inputs[key] = inputs[key].squeeze(1).cuda(non_blocking=True) + # del inputs['token_type_ids'] + targets = targets.cuda(non_blocking=True) + + + with torch.no_grad(): + loss, enhanced_features, image_features, text_features = model(inputs) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + 
result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + result.setdefault(IT_aug_wmap, []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, 
topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + gathered_results = dict() + for key, value in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): + acc = value.float().mean().item() + results.setdefault(key, acc) + logger.info(f"{key}: {acc}") + + return results + +def test1_Albef_one_epoch(test_loader, model, stage, args): + model.eval() + # cnt = 0 + + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5,10,20,50,100] + + result = dict() + results = dict() + + for i, (inputs, targets) in enumerate(tqdm(test_loader)): + + inputs['img'] = inputs['img'].cuda(non_blocking=True) + + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + with autocast(): + loss, enhanced_features, image_features, text_features = model(inputs) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for 
topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + 
result.setdefault(IT_aug_wmap , []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + gathered_results = dict() + for key, value in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): + acc = value.float().mean().item() + results.setdefault(key, acc) + logger.info(f"{key}: {acc}") + + + return results + + + + # F1_3,P_3,R_3 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # F1_5,P_5,R_5 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # logger.info('g_k=3: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_3), torch.mean(P_3), torch.mean(R_3))) + # logger.info('g_k=5: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_5), torch.mean(P_5), torch.mean(R_5))) + +def test1_Align_one_epoch(test_loader, model, stage, args): + model.eval() + # cnt = 0 + + # 
Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5,10,20,50,100] + + result = dict() + results = dict() + + for i, (inputs, targets) in enumerate(tqdm(test_loader)): + for key in inputs: + inputs[key] = inputs[key].squeeze(1).cuda(non_blocking=True) + + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + with autocast(): + loss, enhanced_features, image_features, text_features = model(inputs) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = 
it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + result.setdefault(IT_aug_wmap , []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + gathered_results = dict() + for key, value 
in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): + acc = value.float().mean().item() + results.setdefault(key, acc) + logger.info(f"{key}: {acc}") + + + return results + + + + # F1_3,P_3,R_3 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # F1_5,P_5,R_5 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # logger.info('g_k=3: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_3), torch.mean(P_3), torch.mean(R_3))) + # logger.info('g_k=5: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_5), torch.mean(P_5), torch.mean(R_5))) + + +def test1_CISEN_one_epoch(test_loader, model, stage, args): + model.eval() + # cnt = 0 + + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5,10,20,50,100] + + result = dict() + results = dict() + + for i, (image, text, targets) in enumerate(tqdm(test_loader)): + image = image.cuda(non_blocking=True) + text = text.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + loss, enhanced_features, image_features, text_features = model(image, text, stage) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + 
str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = "i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + result.setdefault(IT_aug_wmap , []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, 
[]).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + gathered_results = dict() + for key, value in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): + acc = value.float().mean().item() + results.setdefault(key, acc) + logger.info(f"{key}: {acc}") + + + return results + +def test_CISEN_lclip(test_loader, model, stage, args): + model.eval() + # cnt = 0 + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + k = [5,10,20,50,100] + + result = dict() + results = dict() + + for i, (image, text, targets) in enumerate(tqdm(test_loader)): + image = image.cuda(non_blocking=True) + text = text.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + loss, enhanced_features, image_features, text_features = model(image, text, stage) + 
# logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + + IT_acg = "i->t:acg@" + str(topk) + IT_ndcg = "i->t:ndcg@" + str(topk) + IT_map = "i->t:map@" + str(topk) + IT_wmap = "i->t:wmap@" + str(topk) + + TI_acg = "t->i:acg@" + str(topk) + TI_ndcg = "t->i:ndcg@" + str(topk) + TI_map = "t->i:map@" + str(topk) + TI_wmap = "t->i:wmap@" + str(topk) + + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault(IT_acg, []).append(it_acg) + result.setdefault(IT_ndcg, []).append(it_ndcg) + result.setdefault(IT_map, []).append(it_map) + result.setdefault(IT_wmap, []).append(it_wmap) + + result.setdefault(TI_acg, []).append(ti_acg) + result.setdefault(TI_ndcg, []).append(ti_ndcg) + result.setdefault(TI_map, []).append(ti_map) + result.setdefault(TI_wmap, []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + for topk in k: + if topk > it_logits.shape[0]: + continue + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + IT_aug_acg = "i->t(aug):acg@" + str(topk) + IT_aug_ndcg = "i->t(aug):ndcg@" + str(topk) + IT_aug_map = "i->t(aug):map@" + str(topk) + IT_aug_wmap = 
"i->t(aug):wmap@" + str(topk) + + TI_aug_acg = "t->i(aug):acg@" + str(topk) + TI_aug_ndcg = "t->i(aug):ndcg@" + str(topk) + TI_aug_map = "t->i(aug):map@" + str(topk) + TI_aug_wmap = "t->i(aug):wmap@" + str(topk) + + result.setdefault(IT_aug_acg, []).append(it_aug_acg) + result.setdefault(IT_aug_ndcg, []).append(it_aug_ndcg) + result.setdefault(IT_aug_map, []).append(it_aug_map) + result.setdefault(IT_aug_wmap , []).append(it_aug_wmap) + + result.setdefault(TI_aug_acg, []).append(ti_aug_acg) + result.setdefault(TI_aug_ndcg, []).append(ti_aug_ndcg) + result.setdefault(TI_aug_map, []).append(ti_aug_map) + result.setdefault(TI_aug_wmap, []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + gathered_results = dict() + for key, value in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): + acc = value.float().mean().item() + results.setdefault(key, acc) + logger.info(f"{key}: {acc}") + + + return results + # F1_3,P_3,R_3 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # 
F1_5,P_5,R_5 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # logger.info('g_k=3: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_3), torch.mean(P_3), torch.mean(R_3))) + # logger.info('g_k=5: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_5), torch.mean(P_5), torch.mean(R_5))) + +def test2_CISEN_one_epoch(test_loader, model, stage, args): + model.eval() + # cnt = 0 + test_batch_size = args.test_batch_size + # Top1_acc = torch.empty(len_testdataset, 3) + # Topk_acc = torch.empty(len_testdataset, 3) + Top1_acc = [] + Topk_acc = [] + topk = 5 + + result = dict() + + for i, (image, text, targets) in enumerate(tqdm(test_loader)): + image = image.cuda(non_blocking=True) + text = text.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # targets = torch.arange(image.shape[0]).cuda() + + with torch.no_grad(): + loss, enhanced_features, image_features, text_features = model(image, text, stage) + # logits_per_image = image_features @ enhanced_features.t() + # logits_per_text = enhanced_features @ text_features.t() + if image_features is not None: + it_logits = image_features @ text_features.t() + ti_logits = text_features @ image_features.t() + it_logits = it_logits / it_logits.norm(dim=-1, keepdim=True) + ti_logits = ti_logits / ti_logits.norm(dim=-1, keepdim=True) + it_acg, it_ndcg, it_map, it_wmap = compute_ACG(it_logits, targets, topk) + ti_acg, ti_ndcg, ti_map, ti_wmap = compute_ACG(ti_logits, targets, topk) + # record + result.setdefault("i->t:acg", []).append(it_acg) + result.setdefault("i->t:ndcg", []).append(it_ndcg) + result.setdefault("i->t:map", []).append(it_map) + result.setdefault("i->t:wmap", []).append(it_wmap) + + result.setdefault("t->i:acg", []).append(ti_acg) + result.setdefault("t->i:ndcg", []).append(ti_ndcg) + result.setdefault("t->i:map", []).append(ti_map) + result.setdefault("t->i:wmap", []).append(ti_wmap) + if enhanced_features is not None: + it_aug_logits = enhanced_features @ text_features.t() + ti_aug_logits = 
text_features @ enhanced_features.t() + it_aug_logits = it_aug_logits / it_aug_logits.norm(dim=-1, keepdim=True) + ti_aug_logits = ti_aug_logits / ti_aug_logits.norm(dim=-1, keepdim=True) + it_aug_acg, it_aug_ndcg, it_aug_map, it_aug_wmap = compute_ACG(it_aug_logits, targets, topk) + ti_aug_acg, ti_aug_ndcg, ti_aug_map, ti_aug_wmap = compute_ACG(ti_aug_logits, targets, topk) + # record + result.setdefault("i->t(aug):acg", []).append(it_aug_acg) + result.setdefault("i->t(aug):ndcg", []).append(it_aug_ndcg) + result.setdefault("i->t(aug):map", []).append(it_aug_map) + result.setdefault("i->t(aug):wmap", []).append(it_aug_wmap) + + result.setdefault("t->i(aug):acg", []).append(ti_aug_acg) + result.setdefault("t->i(aug):ndcg", []).append(ti_aug_ndcg) + result.setdefault("t->i(aug):map", []).append(ti_aug_map) + result.setdefault("t->i(aug):wmap", []).append(ti_aug_wmap) + + # logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) + # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) + + # augi_acc1, augi_acck = compute_acc(logits_per_image, targets, topk) + # augt_acc1, augt_acck = compute_acc(logits_per_text, targets, topk) + + # Topk_acc = "raw_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(raw_acck) + + # result.setdefault("augi_Top1_acc", []).append(augi_acc1) + # Topk_acc = "augi_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augi_acck) + + # result.setdefault("augt_Top1_acc", []).append(augt_acc1) + # Topk_acc = "augt_Top" + str(topk) + "_acc" + # result.setdefault(Topk_acc, []).append(augt_acck) + + gathered_results = dict() + for key, value in result.items(): + tensor_list = [torch.zeros_like(torch.cat(value)).cuda() for _ in range(dist.get_world_size())] + + dist.all_gather(tensor_list, torch.cat(value)) + + gathered_results[key] = torch.cat(tensor_list) + + # 在主 GPU 上进行后续的处理或输出 + if dist.get_rank() == 0: + # 输出汇总结果 + for key, value in gathered_results.items(): 
+ acc = value.float().mean().item() + logger.info(f"{key}: {acc}") + + + + # F1_3,P_3,R_3 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # F1_5,P_5,R_5 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) + # logger.info('g_k=3: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_3), torch.mean(P_3), torch.mean(R_3))) + # logger.info('g_k=5: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_5), torch.mean(P_5), torch.mean(R_5))) + +# def test2_CISEN_one_epoch(test_loader, model, len_testdataset, args): +# logger.info("=======================TEST MODE=======================") +# model.eval() +# # cnt = 0 +# test_batch_size = args.test_batch_size +# # Top1_acc = torch.empty(len_testdataset, 3) +# # Topk_acc = torch.empty(len_testdataset, 3) +# Top1_acc = [] +# Topk_acc = [] +# topk = 5 +# result = dict() +# +# for i, (image, text, _) in enumerate(tqdm(test_loader)): +# text = text.cuda() +# targets = torch.arange(image.shape[0]).to(text.device) +# for idx in range(image.shape[0]): +# with torch.no_grad(): +# imgs = image[idx,:,:].unsqueeze(0) +# imgs = imgs.cuda() +# loss, enhanced_features, image_features, text_features = model(imgs, text.unsqueeze(0), '2nd') +# enhanced_features = enhanced_features.squeeze(0) +# # enhanced_features = torch.stack(enhanced_features).squeeze(1) +# # image2text +# logits_per_image = image_features @ enhanced_features.t() +# # logits_per_text = text_features @ enhanced_features.t() +# +# logits_per_image = logits_per_image / logits_per_image.norm(dim=-1, keepdim=True) +# # logits_per_text = logits_per_text / logits_per_text.norm(dim=-1, keepdim=True) +# +# augi_acc1, augi_acck = compute_acc(logits_per_image, targets[idx], topk) +# # augt_acc1, augt_acck = compute_acc(logits_per_text, targets[idx], topk) +# +# result.setdefault("augi_Top1_acc", []).append(augi_acc1) +# Topk_acc = "augi_Top" + str(topk) + "_acc" +# result.setdefault(Topk_acc, []).append(augi_acck) +# +# # result.setdefault("augt_Top1_acc", 
[]).append(augt_acc1) +# # Topk_acc = "augt_Top" + str(topk) + "_acc" +# # result.setdefault(Topk_acc, []).append(augt_acck) +# +# for key, acc in result.items(): +# acc = torch.cat(acc) +# acc = acc.float().mean().item() * 100 +# result[key] = acc +# +# +# +# log_info = [f"{k}: {v:.3f}" for k, v in result.items()] +# logger.info(f"[Results] {'; '.join(log_info)}") +# +# # F1_3,P_3,R_3 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) +# # F1_5,P_5,R_5 = compute_F1(logits1.cuda(), target.cuda(), 'overall', k_val=3) +# # logger.info('g_k=3: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_3), torch.mean(P_3), torch.mean(R_3))) +# # logger.info('g_k=5: {:.4f},{:.4f},{:.4f}'.format(torch.mean(F1_5), torch.mean(P_5), torch.mean(R_5)))