"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import json
import logging
import os

import numpy as np
import torch
from lavis.common.dist_utils import is_main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask


@registry.register_task("retrieval")
class RetrievalTask(BaseTask):
    """Image-text retrieval task: reports recall@{1,5,10} for both retrieval directions."""

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        return cls(cfg=run_cfg)

    def evaluation(self, model, data_loader, **kwargs):
        # Compute the full image-to-text and text-to-image similarity matrices
        # for the evaluation split.
        score_i2t, score_t2i = model.compute_sim_matrix(data_loader, task_cfg=self.cfg)

        if is_main_process():
            # Only the main process computes and logs the retrieval metrics.
            eval_result = self._report_metrics(
                score_i2t,
                score_t2i,
                data_loader.dataset.txt2img,
                data_loader.dataset.img2txt,
            )
            logging.info(eval_result)
        else:
            eval_result = None

        return eval_result

    def after_evaluation(self, val_result, **kwargs):
        return val_result

    @staticmethod
    @torch.no_grad()
    def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt):
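        # Image -> Text retrieval: for each image, rank of its best-matching
        # ground-truth caption in the similarity-sorted caption list.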
        ranks = np.zeros(scores_i2t.shape[0])
        for index, score in enumerate(scores_i2t):
            inds = np.argsort(score)[::-1]

            # An image may have several ground-truth captions; keep the best rank.
            rank = 1e20
            for i in img2txt[index]:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

        # Text-retrieval recall@1, recall@5, recall@10 (percentages).
        tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
        tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
        tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

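        # Text -> Image retrieval: rank of the ground-truth image for each caption.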
        ranks = np.zeros(scores_t2i.shape[0])

        for index, score in enumerate(scores_t2i):
            inds = np.argsort(score)[::-1]
            ranks[index] = np.where(inds == txt2img[index])[0][0]

        # Image-retrieval recall@1, recall@5, recall@10 (percentages).
        ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
        ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
        ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

        tr_mean = (tr1 + tr5 + tr10) / 3
        ir_mean = (ir1 + ir5 + ir10) / 3
        r_mean = (tr_mean + ir_mean) / 2

        # Aggregate score (mean of the text-retrieval recalls).
        agg_metrics = (tr1 + tr5 + tr10) / 3

        eval_result = {
            "txt_r1": tr1,
            "txt_r5": tr5,
            "txt_r10": tr10,
            "txt_r_mean": tr_mean,
            "img_r1": ir1,
            "img_r5": ir5,
            "img_r10": ir10,
            "img_r_mean": ir_mean,
            "r_mean": r_mean,
            "agg_metrics": agg_metrics,
        }

        # Append the metrics as one JSON line to evaluate.txt in the output directory.
        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(eval_result) + "\n")

        return eval_result