""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |

import json
import os

import numpy as np

from lavis.common.dist_utils import main_process
from lavis.common.logger import MetricLogger
from lavis.common.registry import registry
from lavis.datasets.data_utils import prepare_sample
from lavis.tasks.base_task import BaseTask


# register the task so it can be built from a config by name
@registry.register_task("dialogue")
class DialogueTask(BaseTask):
    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate
        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.num_beams
        max_len = run_cfg.max_len
        min_len = run_cfg.min_len
        evaluate = run_cfg.evaluate

        report_metric = run_cfg.get("report_metric", True)

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            report_metric=report_metric,
        )
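
    # A rough sketch of the run-config block that setup_task() reads. The key
    # names come from the attribute accesses above; the YAML layout, the task
    # name, and the concrete values are assumptions for illustration only:
    #
    #   run:
    #     task: dialogue
    #     num_beams: 5
    #     max_len: 20
    #     min_len: 1
    #     evaluate: False
    #     report_metric: True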

    def valid_step(self, model, samples):
        # Validation scores the batch by forward loss rather than by generating
        # responses; the per-batch loss is returned for later aggregation.
        loss = model(samples)["loss"].item()

        return [loss]

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        if self.report_metric:
            # val_result is the list of per-batch losses from valid_step;
            # the aggregate metric for this split is their mean.
            avg_loss = np.mean(val_result)
            metrics = {"agg_metrics": avg_loss}
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        # TODO better way to define this
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_dialogue_eval(coco_gt_root, eval_result_file, split_name)

        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in coco_val.eval.items()}
        coco_res["agg_metrics"] = agg_metrics

        return coco_res


# TODO better structure for this.
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url


def coco_dialogue_eval(coco_gt_root, results_file, split):
    urls = {
        "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
        "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
    }
    filenames = {
        "val": "coco_karpathy_val_gt.json",
        "test": "coco_karpathy_test_gt.json",
    }

    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # to evaluate on a subset of images only, uncomment the following line;
    # leave it commented out to evaluate the full validation set
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval
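

# A minimal offline sketch of how coco_dialogue_eval() can be driven directly.
# Everything under the guard is illustrative: the temp-dir paths and the
# `generated` placeholder are assumptions, not part of the task above. The one
# hard requirement, inherited from COCO.loadRes(), is that the results file be
# a JSON list of {"image_id": ..., "caption": ...} records whose ids appear in
# the downloaded ground-truth annotations.
if __name__ == "__main__":
    import tempfile

    # Stand-in for model outputs: image id -> generated response text.
    generated = {}  # e.g. {42: "a short generated response", ...}

    results = [{"image_id": i, "caption": text} for i, text in generated.items()]

    tmp_dir = tempfile.gettempdir()
    results_file = os.path.join(tmp_dir, "val_dialogue_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f)

    # Only run the download + metric step when there is something to score.
    if generated:
        coco_dialogue_eval(os.path.join(tmp_dir, "coco_gt"), results_file, split="val")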