"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os
import re

from lavis.common.dist_utils import main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask


@registry.register_task("captioning")
class CaptionTask(BaseTask):
    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.num_beams
        max_len = run_cfg.max_len
        min_len = run_cfg.min_len
        evaluate = run_cfg.evaluate

        report_metric = run_cfg.get("report_metric", True)

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            report_metric=report_metric,
        )

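    # A minimal sketch of the run-config keys read by setup_task above, assuming
    # the usual LAVIS YAML layout; the values shown are illustrative, not
    # defaults defined by this module:
    #
    # run:
    #   task: captioning
    #   num_beams: 3
    #   max_len: 20
    #   min_len: 5
    #   evaluate: True
    #   report_metric: True
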
    def valid_step(self, model, samples):
        results = []

        # Generate one caption per image with beam search (nucleus sampling disabled).
        captions = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=self.num_beams,
            max_length=self.max_len,
            min_length=self.min_len,
        )

        img_ids = samples["image_id"]
        for caption, img_id in zip(captions, img_ids):
            # Drop any alphabetic characters so the image id can be cast to int,
            # keeping the {"image_id": int, "caption": str} layout that
            # pycocotools' COCO.loadRes expects (see coco_caption_eval below).
            img_id = re.sub(r"[A-Za-z]", "", str(img_id))
            results.append({"caption": caption, "image_id": int(img_id)})

        return results

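    # after_evaluation gathers the per-batch dicts returned by valid_step,
    # writes them to a single JSON result file via save_result (inherited from
    # BaseTask), and optionally scores that file against the COCO ground truth.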
    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="image_id",
        )

        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        # Ground-truth annotations are cached under the LAVIS cache root.
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_caption_eval(coco_gt_root, eval_result_file, split_name)

        # Single aggregate score (CIDEr + BLEU-4) reported as agg_metrics.
        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}

        # Append one JSON line per evaluation run to evaluate.txt.
        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in coco_val.eval.items()}
        coco_res["agg_metrics"] = agg_metrics

        return coco_res

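# COCO caption evaluation helper. The imports below are used only by
# coco_caption_eval, which downloads the Karpathy-split ground-truth
# annotations and scores a result file with pycocoevalcap.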
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url


def coco_caption_eval(coco_gt_root, results_file, split):
    urls = {
        "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
        "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
    }
    filenames = {
        "val": "coco_karpathy_val_gt.json",
        "test": "coco_karpathy_test_gt.json",
    }

    # Download the ground-truth annotation file for the requested split.
    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # Build COCO objects for the ground truth and the generated captions.
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # Score the results (BLEU, METEOR, ROUGE-L, CIDEr, SPICE).
    coco_eval = COCOEvalCap(coco, coco_result)
    coco_eval.evaluate()

    # Print the individual metric scores.
    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval

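# A minimal usage sketch (paths are illustrative and not defined by this module):
#
#   coco_eval = coco_caption_eval(
#       coco_gt_root="/path/to/lavis_cache/coco_gt",
#       results_file="output/result/val_epoch0.json",
#       split="val",
#   )
#   print(coco_eval.eval["CIDEr"])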