# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torchvision.ops.boxes as boxes
from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor):
    """Compute the element-wise IoU of two aligned sets of boxes.

    Both inputs have shape (B, 4) in ``(x1, y1, x2, y2)`` format; the i-th
    box of ``boxes1`` is compared with the i-th box of ``boxes2``, so the
    result has shape (B, ).
    """
    area1 = boxes.box_area(boxes1)
    area2 = boxes.box_area(boxes2)

    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # (B, 2)
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # (B, 2)

    wh = boxes._upcast(rb - lt).clamp(min=0)  # (B, 2)
    inter = wh[:, 0] * wh[:, 1]  # (B, )

    union = area1 + area2 - inter
    iou = inter / union
    return iou


@METRICS.register_module()
class VisualGroundingMetric(BaseMetric):
    """Visual Grounding evaluator.

    Calculate the box mIoU and the box grounding accuracy (the fraction of
    predictions with IoU >= 0.5) for visual grounding models.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix = 'visual-grounding'

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for preds in data_samples:
            pred_box = preds['pred_bboxes'].squeeze()
            box_gt = torch.Tensor(preds['gt_bboxes']).squeeze()

            result = {
                'box': pred_box.to('cpu'),
                'box_target': box_gt,
            }
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        pred_boxes = torch.stack([each['box'] for each in results])
        gt_boxes = torch.stack([each['box_target'] for each in results])
        iou = aligned_box_iou(pred_boxes, gt_boxes)
        accu_num = torch.sum(iou >= 0.5)

        miou = torch.mean(iou)
        acc = accu_num / len(gt_boxes)
        coco_val = {'miou': miou, 'acc': acc}
        return coco_val
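

if __name__ == '__main__':
    # A minimal smoke-test sketch of the metric above, runnable standalone.
    # The sample boxes below are hypothetical values in (x1, y1, x2, y2)
    # format, chosen only to illustrate the expected input layout; this
    # block is not part of the mmpretrain API.
    metric = VisualGroundingMetric()
    metric.process(
        data_batch=None,
        data_samples=[{
            'pred_bboxes': torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
            'gt_bboxes': [[12.0, 12.0, 48.0, 52.0]],
        }],
    )
    # The IoU of this single pair is about 0.82, so the expected output is
    # roughly {'miou': tensor(0.8182), 'acc': tensor(1.)}.
    print(metric.compute_metrics(metric.results))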