"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os

import numpy as np
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url

from lavis.common.dist_utils import main_process
from lavis.common.logger import MetricLogger
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from lavis.datasets.data_utils import prepare_sample


@registry.register_task("dialogue")
class DialogueTask(BaseTask):
    """Dialogue task, registered with the task registry under ``"dialogue"``.

    Validation in this class is loss-based: ``valid_step`` records the model's
    loss for each batch and ``after_evaluation`` averages those losses into
    ``agg_metrics``.
    """

    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        """Store generation and evaluation settings.

        Args:
            num_beams: Beam count for generation (stored only; not read
                anywhere in this module).
            max_len: Maximum generation length (stored only in this module).
            min_len: Minimum generation length (stored only in this module).
            evaluate: Evaluation flag taken from the run config (stored only
                in this module).
            report_metric: When False, ``after_evaluation`` reports a constant
                ``0.0`` instead of the mean validation loss.
        """
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from ``cfg.run_cfg``.

        ``num_beams``, ``max_len``, ``min_len`` and ``evaluate`` are required
        attributes of ``run_cfg``; ``report_metric`` is optional and defaults
        to True.
        """
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.num_beams
        max_len = run_cfg.max_len
        min_len = run_cfg.min_len
        evaluate = run_cfg.evaluate

        report_metric = run_cfg.get("report_metric", True)

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            report_metric=report_metric,
        )

    def valid_step(self, model, samples):
        """Compute the loss for one validation batch.

        Args:
            model: Callable returning a mapping with a ``"loss"`` entry that
                supports ``.item()`` (presumably a scalar tensor).
            samples: Batch passed straight to ``model``.

        Returns:
            A one-element list ``[loss]`` holding the batch loss as a Python
            number (presumably so the caller can concatenate per-batch
            results — confirm against the evaluation loop).
        """
        loss = model(samples)["loss"].item()
        return [loss]

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        """Aggregate the collected validation losses into a metrics dict.

        Args:
            val_result: Collected ``valid_step`` outputs; presumably a flat
                sequence of per-batch losses (TODO confirm against the caller).
            split_name: Not used by this implementation.
            epoch: Not used by this implementation.

        Returns:
            ``{"agg_metrics": mean_loss}`` when ``report_metric`` is enabled,
            otherwise ``{"agg_metrics": 0.0}``.
        """
        # NOTE(review): agg_metrics is a *loss* here (lower is better). If the
        # runner keeps the checkpoint with the highest agg_metrics, this would
        # favour the worst epoch -- confirm against the runner.
        if not self.report_metric:
            return {"agg_metrics": 0.0}

        # np.mean of an empty sequence yields nan (with a RuntimeWarning).
        avg_loss = np.mean(val_result)
        return {"agg_metrics": avg_loss}

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        """Score a result file against COCO Karpathy captioning ground truth.

        Decorated with ``main_process`` (presumably restricting execution to
        the main/rank-0 process -- confirm in ``lavis.common.dist_utils``).

        The per-metric scores are appended as one JSON line, keyed by
        ``split_name``, to ``evaluate.txt`` in the registry's ``output_dir``.

        Args:
            eval_result_file: Path to the results file handed to
                ``coco_dialogue_eval``.
            split_name: ``"val"`` or ``"test"`` (other values raise KeyError
                inside ``coco_dialogue_eval``).

        Returns:
            A dict of all scores reported by ``COCOEvalCap`` plus
            ``"agg_metrics"`` = ``CIDEr + Bleu_4``.
        """
        # TODO better way to define this
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_dialogue_eval(coco_gt_root, eval_result_file, split_name)

        scores = dict(coco_val.eval)
        agg_metrics = scores["CIDEr"] + scores["Bleu_4"]

        log_stats = {split_name: scores}
        log_path = os.path.join(registry.get_path("output_dir"), "evaluate.txt")
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_stats) + "\n")

        # Separate copy so the returned dict carries agg_metrics without
        # touching the scores that were just logged.
        coco_res = dict(scores)
        coco_res["agg_metrics"] = agg_metrics

        return coco_res


# TODO better structure for this.
def coco_dialogue_eval(coco_gt_root, results_file, split):
    """Evaluate a results file with COCOEvalCap against Karpathy-split GT.

    Despite the name, the ground truth used here is the COCO Karpathy
    captioning annotation file for the requested split.

    Args:
        coco_gt_root: Directory the ground-truth JSON is downloaded into and
            read from.
        results_file: Results file path passed to ``COCO.loadRes`` (expected
            to be in the format pycocotools accepts for results).
        split: ``"val"`` or ``"test"``; any other value raises KeyError.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()``; its ``eval`` attribute
        maps metric names to scores. Each score is also printed.
    """
    urls = {
        "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
        "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
    }
    filenames = {
        "val": "coco_karpathy_val_gt.json",
        "test": "coco_karpathy_test_gt.json",
    }

    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # Ground-truth COCO object and the matching results object.
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    coco_eval = COCOEvalCap(coco, coco_result)

    # To score only the images present in the results (e.g. a subset of the
    # validation set), uncomment the following line:
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval