"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os
import re

from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url

from lavis.common.dist_utils import main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask

# Matches ASCII letters; valid_step strips them from image ids before int().
_ASCII_LETTERS_RE = re.compile(r"[A-Za-z]")

# Ground-truth caption annotations for the COCO Karpathy splits, keyed by split.
_COCO_KARPATHY_GT_URLS = {
    "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
    "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
}
_COCO_KARPATHY_GT_FILENAMES = {
    "val": "coco_karpathy_val_gt.json",
    "test": "coco_karpathy_test_gt.json",
}


@registry.register_task("captioning")
class CaptionTask(BaseTask):
    """Image-captioning task: generates captions and scores them with COCO metrics.

    Registered in the LAVIS registry under the name ``"captioning"``.
    """

    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        """Store the decoding and evaluation settings.

        Args:
            num_beams: Beam width forwarded to ``model.generate``.
            max_len: Maximum caption length forwarded as ``max_length``.
            min_len: Minimum caption length forwarded as ``min_length``.
            evaluate: The ``run_cfg.evaluate`` flag (stored, not used here).
            report_metric: If False, ``after_evaluation`` skips COCO scoring
                and reports ``{"agg_metrics": 0.0}``.
        """
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from ``cfg.run_cfg``.

        ``num_beams``, ``max_len``, ``min_len`` and ``evaluate`` are required.
        ``report_metric`` is optional and defaults to True.
        """
        run_cfg = cfg.run_cfg

        return cls(
            num_beams=run_cfg.num_beams,
            max_len=run_cfg.max_len,
            min_len=run_cfg.min_len,
            evaluate=run_cfg.evaluate,
            report_metric=run_cfg.get("report_metric", True),
        )

    def valid_step(self, model, samples):
        """Generate captions for one batch with beam search (no nucleus sampling).

        Args:
            model: Object exposing ``generate(samples, ...)`` that returns one
                caption per sample.
            samples: Batch mapping; must contain ``"image_id"``.

        Returns:
            A list of ``{"caption": str, "image_id": int}`` dicts, one per
            caption/id pair.

        Raises:
            ValueError: If an id is still not an integer literal after its
                ASCII letters are removed.
        """
        captions = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=self.num_beams,
            max_length=self.max_len,
            min_length=self.min_len,
        )

        # Letters are stripped so ids with alphabetic prefixes/suffixes become
        # integers; presumably this matches the integer ids in the ground-truth
        # annotation files -- confirm for non-COCO datasets.
        return [
            {
                "caption": caption,
                "image_id": int(_ASCII_LETTERS_RE.sub("", str(img_id))),
            }
            for caption, img_id in zip(captions, samples["image_id"])
        ]

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        """Save the collected predictions and optionally score them.

        Results go to ``<result_dir>/<split_name>_epoch<epoch>``, de-duplicated
        on ``"image_id"``.

        Returns:
            The metrics dict from ``_report_metrics`` when ``report_metric`` is
            set, otherwise ``{"agg_metrics": 0.0}``.
        """
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="image_id",
        )

        if not self.report_metric:
            return {"agg_metrics": 0.0}

        return self._report_metrics(
            eval_result_file=eval_result_file, split_name=split_name
        )

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        """Score a result file against COCO Karpathy ground truth.

        Also appends ``{split_name: scores}`` as one JSON line to
        ``<output_dir>/evaluate.txt``. The logged line does not include
        ``agg_metrics``.

        Presumably returns None on non-main processes, via ``main_process``
        -- verify in ``lavis.common.dist_utils``.

        Returns:
            All COCOEvalCap scores plus ``"agg_metrics"`` (CIDEr + Bleu_4).
        """
        # TODO better way to define this
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_caption_eval(coco_gt_root, eval_result_file, split_name)

        scores = dict(coco_val.eval)
        agg_metrics = scores["CIDEr"] + scores["Bleu_4"]

        log_path = os.path.join(registry.get_path("output_dir"), "evaluate.txt")
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps({split_name: scores}) + "\n")

        return {**scores, "agg_metrics": agg_metrics}


def coco_caption_eval(coco_gt_root, results_file, split):
    """Evaluate a caption result file against a COCO Karpathy split.

    Fetches the ground-truth JSON for ``split`` into ``coco_gt_root`` via
    torchvision's ``download_url``, runs COCOEvalCap, and prints each metric.

    Args:
        coco_gt_root: Directory holding (or receiving) the ground-truth JSON.
        results_file: Path to a result JSON in the format accepted by
            ``COCO.loadRes``.
        split: ``"val"`` or ``"test"``.

    Returns:
        The evaluated ``COCOEvalCap`` object; scores are in its ``.eval``.

    Raises:
        KeyError: If ``split`` is not ``"val"`` or ``"test"``.
    """
    download_url(_COCO_KARPATHY_GT_URLS[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, _COCO_KARPATHY_GT_FILENAMES[split])

    # Ground truth and predictions as COCO API objects.
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    coco_eval = COCOEvalCap(coco, coco_result)

    # To score only the images present in the results, set
    # coco_eval.params["image_id"] = coco_result.getImgIds() before evaluate().

    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval