""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import json import logging import os import numpy as np import torch from lavis.common.dist_utils import is_main_process from lavis.common.registry import registry from lavis.tasks.base_task import BaseTask @registry.register_task("retrieval") class RetrievalTask(BaseTask): def __init__(self, cfg): super().__init__() self.cfg = cfg @classmethod def setup_task(cls, cfg): run_cfg = cfg.run_cfg return cls(cfg=run_cfg) def evaluation(self, model, data_loader, **kwargs): # score_i2t, score_t2i = model.compute_sim_matrix(model, data_loader) score_i2t, score_t2i = model.compute_sim_matrix(data_loader, task_cfg=self.cfg) if is_main_process(): eval_result = self._report_metrics( score_i2t, score_t2i, data_loader.dataset.txt2img, data_loader.dataset.img2txt, ) logging.info(eval_result) else: eval_result = None return eval_result def after_evaluation(self, val_result, **kwargs): return val_result @staticmethod @torch.no_grad() def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt): # Images->Text ranks = np.zeros(scores_i2t.shape[0]) for index, score in enumerate(scores_i2t): inds = np.argsort(score)[::-1] # Score rank = 1e20 for i in img2txt[index]: tmp = np.where(inds == i)[0][0] if tmp < rank: rank = tmp ranks[index] = rank # Compute metrics tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) # Text->Images ranks = np.zeros(scores_t2i.shape[0]) for index, score in enumerate(scores_t2i): inds = np.argsort(score)[::-1] ranks[index] = np.where(inds == txt2img[index])[0][0] # Compute metrics ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) tr_mean = (tr1 + tr5 + tr10) / 3 ir_mean = (ir1 + ir5 + ir10) / 3 r_mean = (tr_mean + ir_mean) / 2 agg_metrics = (tr1 + tr5 + tr10) / 3 eval_result = { "txt_r1": tr1, "txt_r5": tr5, "txt_r10": tr10, "txt_r_mean": tr_mean, "img_r1": ir1, "img_r5": ir5, "img_r10": ir10, "img_r_mean": ir_mean, "r_mean": r_mean, "agg_metrics": agg_metrics, } with open( os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" ) as f: f.write(json.dumps(eval_result) + "\n") return eval_result