from typing import Callable, Dict, Sequence, Union

import torch
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from monai.apps.detection.metrics.coco import COCOMetric
from monai.apps.detection.metrics.matching import matching_batch
from monai.data import box_utils

from .utils import detach_to_numpy


class IgniteCocoMetric(Metric):
    """Ignite metric that accumulates detection outputs and scores them with a MONAI ``COCOMetric``."""

    def __init__(
        self,
        coco_metric_monai: Union[None, COCOMetric] = None,
        box_key="box",
        label_key="label",
        pred_score_key="label_scores",
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device, None] = None,
        reduce_scalar: bool = True,
    ):
        r"""
        Computes coco detection metric in Ignite.

        Args:
            coco_metric_monai: the coco metric in monai.
                If not given, will assume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
            box_key: box key in the ground truth target dict and prediction dict.
            label_key: classification label key in the ground truth target dict and prediction dict.
            pred_score_key: classification score key in the prediction dict.
            output_transform: A callable that is used to transform the Engine's
                process_function's output into the form expected by the metric.
            device: specifies which device updates are accumulated on.
                Setting the metric's device to be the same as your update arguments
                ensures the update method is non-blocking. By default, CPU.
            reduce_scalar: if True, will return the average value of the coco metric values;
                if False, will return a dictionary of coco metrics.

        Examples:
            To use with ``Engine`` and ``process_function``,
            simply attach the metric instance to the engine.
            The output of the engine's ``process_function`` needs to be in format of
            ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
            For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
            visit :ref:`attach-engine`.

            .. include:: defaults.rst
                :start-after: :orphan:

            .. testcode::

                coco = IgniteCocoMetric()
                coco.attach(default_evaluator, 'coco')
                preds = [
                    {
                        'box': torch.Tensor([[1,1,1,2,2,2]]),
                        'label':torch.Tensor([0]),
                        'label_scores':torch.Tensor([0.8])
                    }
                ]
                target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
                state = default_evaluator.run([[preds, target]])
                print(state.metrics['coco'])

            .. testoutput::

                1.0...

        .. versionadded:: 0.4.3
        """
        self.box_key = box_key
        self.label_key = label_key
        self.pred_score_key = pred_score_key

        if coco_metric_monai is None:
            self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
        else:
            self.coco_metric = coco_metric_monai
        self.reduce_scalar = reduce_scalar

        if device is None:
            device = torch.device("cpu")
        # The attributes above are assigned before the base-class constructor runs.
        # NOTE(review): presumably because Metric.__init__ may call reset() - confirm.
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Discard every accumulated prediction and target."""
        self.val_targets_all = []
        self.val_outputs_all = []

    @reinit__is_reduced
    def update(self, output: Sequence[Dict]) -> None:
        """
        Append one batch of predictions and targets to the accumulators.

        Args:
            output: a sequence whose first element holds the predictions and whose
                second element holds the ground-truth targets. Both are extended
                onto the internal lists, so each is expected to be an iterable of
                per-image dicts (as in the constructor docstring example).
        """
        y_pred, y = output[0], output[1]
        self.val_outputs_all += y_pred
        self.val_targets_all += y

    # NOTE(review): the synced attributes are Python lists of dicts; confirm that
    # sync_all_reduce handles them as intended in a multi-process run.
    @sync_all_reduce("val_targets_all", "val_outputs_all")
    def compute(self) -> Union[float, Dict]:
        """
        Match the accumulated predictions against the targets and evaluate the coco metric.

        Returns:
            The mean of all coco metric values when ``reduce_scalar`` is True,
            otherwise the dictionary of coco metric values.
        """
        # Convert into locals rather than rebinding the accumulators, so that
        # reading the metric does not replace the state built up by update().
        outputs = detach_to_numpy(self.val_outputs_all)
        targets = detach_to_numpy(self.val_targets_all)

        results_metric = matching_batch(
            iou_fn=box_utils.box_iou,
            iou_thresholds=self.coco_metric.iou_thresholds,
            pred_boxes=[item[self.box_key] for item in outputs],
            pred_classes=[item[self.label_key] for item in outputs],
            pred_scores=[item[self.pred_score_key] for item in outputs],
            gt_boxes=[item[self.box_key] for item in targets],
            gt_classes=[item[self.label_key] for item in targets],
        )
        # Only the first element of the value returned by the coco metric is used.
        val_epoch_metric_dict = self.coco_metric(results_metric)[0]

        if self.reduce_scalar:
            metric_values = val_epoch_metric_dict.values()
            return sum(metric_values) / len(metric_values)
        return val_epoch_metric_dict