import copy
from collections import defaultdict
from pathlib import Path

import torch
import torch.utils.data

import maskrcnn_benchmark.utils.dist as dist
from maskrcnn_benchmark.layers.set_loss import generalized_box_iou

from .modulated_coco import ModulatedDataset


class RefExpDataset(ModulatedDataset):
    """Referring-expression dataset; adds nothing on top of ``ModulatedDataset``."""

    pass


class RefExpEvaluator(object):
    """Accumulate box predictions and report Precision@k per RefCOCO-style dataset.

    A prediction counts as a hit at rank ``k`` when at least one of its ``k``
    highest-scoring boxes reaches ``thresh_iou`` against the single ground-truth
    box of the image. Note that the overlap is measured with
    ``generalized_box_iou`` (GIoU), not plain IoU.
    """

    # Datasets that are always reported, even when no image belongs to them.
    DATASET_NAMES = ("refcoco", "refcoco+", "refcocog")

    def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5):
        """Store a private copy of the ground truth and the evaluation settings.

        Args:
            refexp_gt: Ground-truth index exposing ``imgs``, ``getAnnIds``,
                ``loadImgs`` and ``loadAnns`` (presumably a pycocotools ``COCO``
                object -- confirm against the caller).
            iou_types: Stored on the instance; not used by this class itself.
            k: List or tuple of rank cut-offs for Precision@k.
            thresh_iou: Minimum GIoU for a box to count as a hit.
        """
        assert isinstance(k, (list, tuple))
        # Deep copy so later changes to the caller's object cannot affect evaluation.
        refexp_gt = copy.deepcopy(refexp_gt)
        self.refexp_gt = refexp_gt
        self.iou_types = iou_types
        self.img_ids = self.refexp_gt.imgs.keys()
        self.predictions = {}
        self.k = k
        self.thresh_iou = thresh_iou

    def accumulate(self):
        """No-op, kept so the evaluator offers the usual evaluator interface."""
        pass

    def update(self, predictions):
        """Merge a ``{image_id: prediction}`` mapping into the stored predictions."""
        self.predictions.update(predictions)

    def synchronize_between_processes(self):
        """Gather the predictions of every process and merge them into one dict."""
        all_predictions = dist.all_gather(self.predictions)
        merged_predictions = {}
        for p in all_predictions:
            merged_predictions.update(p)
        self.predictions = merged_predictions

    @staticmethod
    def _xywh_to_xyxy(bbox):
        """Convert an ``[x, y, w, h]`` box into ``[x1, y1, x2, y2]``."""
        x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        return [x, y, w + x, h + y]

    @staticmethod
    def _boxes_by_score(prediction):
        """Return the predicted boxes ordered by descending score, as plain lists.

        ``prediction["scores"]`` and ``prediction["boxes"]`` must support
        ``.tolist()`` (presumably tensors of shape (N,) and (N, 4) -- confirm).
        Ties in score are broken by comparing the box coordinates.
        """
        ranked = sorted(
            zip(prediction["scores"].tolist(), prediction["boxes"].tolist()),
            reverse=True,
        )
        return [box for _, box in ranked]

    def _hits_per_k(self, prediction, target_bbox, ks):
        """Return ``{k: bool}`` telling whether any top-k box matches the target."""
        ranked_boxes = self._boxes_by_score(prediction)
        if not ranked_boxes:
            # No box was predicted for this image: a miss at every rank.
            return {k: False for k in ks}

        boxes = torch.as_tensor(ranked_boxes).view(-1, 4)
        target = torch.as_tensor(self._xywh_to_xyxy(target_bbox)).view(-1, 4)
        # One row per predicted box, one column for the single target box.
        giou = generalized_box_iou(boxes, target)
        matches = giou >= self.thresh_iou
        return {k: bool(matches[:k].any()) for k in ks}

    def summarize(self):
        """Compute, print and return Precision@k for each dataset.

        Returns:
            On the main process, a dict mapping each dataset name to the list of
            precision values ordered by ascending ``k``. A dataset without any
            image reports ``0.0`` for every ``k``. On every other process,
            ``None``.

        Raises:
            KeyError: If an image has no stored prediction, or its
                ``dataset_name`` is not one of ``DATASET_NAMES``.
            AssertionError: If an image does not have exactly one annotation.
        """
        if not dist.is_main_process():
            return None

        # De-duplicate so that a repeated k is neither counted nor divided twice.
        ks = sorted(set(self.k))

        dataset2score = {name: {k: 0.0 for k in ks} for name in self.DATASET_NAMES}
        dataset2count = {name: 0.0 for name in self.DATASET_NAMES}

        for image_id in self.img_ids:
            ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id)
            assert len(ann_ids) == 1
            img_info = self.refexp_gt.loadImgs(image_id)[0]
            target = self.refexp_gt.loadAnns(ann_ids[0])

            prediction = self.predictions[image_id]
            assert prediction is not None

            dataset_name = img_info["dataset_name"]
            hits = self._hits_per_k(prediction, target[0]["bbox"], ks)
            for k, hit in hits.items():
                if hit:
                    dataset2score[dataset_name][k] += 1.0
            dataset2count[dataset_name] += 1.0

        # Turn hit counts into fractions; datasets without images stay at 0.0.
        for name, scores in dataset2score.items():
            count = dataset2count[name]
            if not count:
                continue
            for k in ks:
                scores[k] /= count

        k_label = ", ".join(str(k) for k in ks)
        results = {}
        for name, scores in dataset2score.items():
            results[name] = [scores[k] for k in ks]
            print(f" Dataset: {name} - Precision @ {k_label}: {results[name]} \n")

        return results