# File size: 6,555 Bytes
# 6e6d6a7
# (file-viewer line-number gutter: 1-157)
import numpy as np
import time
import datetime
import torch
import torch.nn.functional as F
import torch.distributed as dist
from models import utils
@torch.no_grad()
def evaluation(args, model, data_loader, device):
    """Compute image->text and text->image retrieval score matrices.

    The function first encodes every text and every image, ranks all pairs
    by the dot product of their normalized projected embeddings, and then
    re-scores only the top-k candidates of each query with the model's
    cross-attention text encoder and ``itm_head``.

    Args:
        args: Namespace-like object; only ``args.distributed`` is read.
        model: Object exposing ``tokenizer``, ``text_encoder``, ``text_proj``,
            ``visual_encoder``, ``vision_proj`` and ``itm_head``.
        data_loader: Iterable yielding ``(image, img_id)`` pairs whose
            ``dataset`` has ``text`` and ``image`` attributes.
        device: Device the tensors are moved to.

    Returns:
        Tuple ``(score_matrix_i2t, score_matrix_t2i)`` of NumPy arrays with
        shapes ``(num_images, num_texts)`` and ``(num_texts, num_images)``.
        Entries that were never re-scored keep the fill value ``-100.0``.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    num_tasks = utils.get_world_size()
    rank = utils.get_rank()

    # ======================================== text feature ======================================== #
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256  # number of captions tokenized/encoded per batch
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=65,
                                     return_tensors="pt").to(device)
        text_feat = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
        # Position 0 of the sequence is projected (presumably the [CLS] token -- confirm).
        text_embed = F.normalize(model.text_proj(text_feat.last_hidden_state[:, 0, :]), dim=-1)
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)
    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    # ======================================== image&sketch feature ======================================== #
    image_feats = []
    image_embeds = []
    for image, _ in data_loader:
        image = image.to(device)
        image_feat = model.visual_encoder(image).last_hidden_state
        image_embed = F.normalize(model.vision_proj(image_feat[:, 0, :]), dim=-1)
        # Full token features are parked on the CPU while iterating the loader.
        image_feats.append(image_feat.cpu())
        image_embeds.append(image_embed)
    image_feats = torch.cat(image_feats, dim=0).to(device)
    image_embeds = torch.cat(image_embeds, dim=0).to(device)
    print('Computing features Cost time {}'.format(time.time() - start_time))

    num_image = len(data_loader.dataset.image)
    k_test = 256  # upper bound on candidates re-scored per query

    # ======================================== i2t score ======================================== #
    sims_matrix = image_embeds @ text_embeds.t()
    score_matrix_i2t = torch.full((num_image, num_text), -100.0).to(device)

    # Each rank re-scores its own contiguous slice of query rows.
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    # topk raises if k exceeds the number of candidates, so clamp it.
    k_i2t = min(k_test, sims_matrix.size(1))
    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=k_i2t, dim=0)
        # Same image features repeated once per candidate text.
        encoder_output = image_feats[start + i].repeat(k_i2t, 1, 1).to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[topk_idx],
                                    attention_mask=text_atts[topk_idx],
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        # Column 1 of the head output is used as the score
        # (presumably the "match" logit -- confirm against itm_head).
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    # ======================================== t2i score ======================================== #
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((num_text, num_image), -100.0).to(device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    k_t2i = min(k_test, sims_matrix.size(1))
    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=k_t2i, dim=0)
        encoder_output = image_feats[topk_idx].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        # Same text tokens repeated once per candidate image.
        output = model.text_encoder(text_ids[start + i].repeat(k_t2i, 1),
                                    attention_mask=text_atts[start + i].repeat(k_t2i, 1),
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2i[start + i, topk_idx] = topk_sim + score

    if args.distributed:
        # Wait for every rank to finish its slice, then sum the per-rank
        # matrices element-wise so each process holds all rows.
        dist.barrier()
        dist.all_reduce(score_matrix_i2t, op=dist.ReduceOp.SUM)
        dist.all_reduce(score_matrix_t2i, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
def _recall_at_1_5_10(ranks):
    """Return the percentages of ``ranks`` strictly below 1, 5 and 10."""
    total = len(ranks)
    if total == 0:
        # No queries at all: report 0.0 instead of dividing by zero.
        return 0.0, 0.0, 0.0
    return tuple(100.0 * int(np.count_nonzero(ranks < k)) / total for k in (1, 5, 10))


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    """Compute recall@1/5/10 for both retrieval directions.

    Args:
        scores_i2t: 2-D array; row ``i`` holds the scores of image ``i``
            against every text.
        scores_t2i: 2-D array; row ``t`` holds the scores of text ``t``
            against every image.
        txt2img: Indexable mapping a text index to its ground-truth image
            index.
        img2txt: Indexable mapping an image index to an iterable of its
            ground-truth text indices.

    Returns:
        Dict with keys ``txt_r1``, ``txt_r5``, ``txt_r10``, ``txt_r_mean``,
        ``img_r1``, ``img_r5``, ``img_r10``, ``img_r_mean`` and ``r_mean``
        (all percentages). Every value is 0.0 for a direction that has no
        query rows.
    """
    # Images->Text: the rank of an image is the best (lowest) position of
    # any of its ground-truth texts in the descending score order.
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        order = np.argsort(score)[::-1]
        # 1e20 is the sentinel kept when an image has no ground-truth text.
        ranks[index] = min(
            (np.where(order == gt)[0][0] for gt in img2txt[index]),
            default=1e20,
        )
    tr1, tr5, tr10 = _recall_at_1_5_10(ranks)

    # Text->Images: each text has a single ground-truth image.
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        order = np.argsort(score)[::-1]
        ranks[index] = np.where(order == txt2img[index])[0][0]
    ir1, ir5, ir10 = _recall_at_1_5_10(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        'txt_r1': tr1,
        'txt_r5': tr5,
        'txt_r10': tr10,
        'txt_r_mean': tr_mean,
        'img_r1': ir1,
        'img_r5': ir5,
        'img_r10': ir10,
        'img_r_mean': ir_mean,
        'r_mean': r_mean,
    }
    return eval_result