from typing import Callable, Dict, Sequence, Union

import torch
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from monai.apps.detection.metrics.coco import COCOMetric
from monai.apps.detection.metrics.matching import matching_batch
from monai.data import box_utils

from .utils import detach_to_numpy
class IgniteCocoMetric(Metric):
    """Ignite ``Metric`` wrapper around MONAI's detection ``COCOMetric``.

    Per-image prediction and target dicts are accumulated in memory by ``update``
    and matched/scored all at once in ``compute``.
    """

    def __init__(
        self,
        coco_metric_monai: Union[None, COCOMetric] = None,
        box_key="box",
        label_key="label",
        pred_score_key="label_scores",
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device, None] = None,
        reduce_scalar: bool = True,
    ):
        r"""
        Computes COCO detection metric in Ignite.

        Args:
            coco_metric_monai: the COCO metric in MONAI.
                If not given, will assume ``COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])``.
            box_key: box key in the ground truth target dict and prediction dict.
            label_key: classification label key in the ground truth target dict and prediction dict.
            pred_score_key: classification score key in the prediction dict.
            output_transform: A callable that is used to transform the Engine’s
                process_function’s output into the form expected by the metric.
            device: specifies which device updates are accumulated on.
                Setting the metric’s device to be the same as your update arguments ensures
                the update method is non-blocking. By default, CPU.
            reduce_scalar: if True, will return the average value of the COCO metric values;
                if False, will return a dictionary of COCO metric values.

        Examples:
            To use with ``Engine`` and ``process_function``,
            simply attach the metric instance to the engine.
            The output of the engine's ``process_function`` needs to be in format of
            ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``, where ``y_pred`` and
            ``y`` are each a list with one dict per image.

            For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
            visit :ref:`attach-engine`.

            .. include:: defaults.rst
                :start-after: :orphan:

            .. testcode::

                coco = IgniteCocoMetric()
                coco.attach(default_evaluator, 'coco')
                preds = [
                    {
                        'box': torch.Tensor([[1,1,1,2,2,2]]),
                        'label':torch.Tensor([0]),
                        'label_scores':torch.Tensor([0.8])
                    }
                ]
                target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
                state = default_evaluator.run([[preds, target]])
                print(state.metrics['coco'])

            .. testoutput::

                1.0...

        .. versionadded:: 0.4.3
        """
        self.box_key = box_key
        self.label_key = label_key
        self.pred_score_key = pred_score_key
        if coco_metric_monai is None:
            self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
        else:
            self.coco_metric = coco_metric_monai
        self.reduce_scalar = reduce_scalar
        if device is None:
            device = torch.device("cpu")
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the accumulated per-image predictions and targets."""
        self.val_targets_all = []
        self.val_outputs_all = []

    @reinit__is_reduced
    def update(self, output: Sequence[Dict]) -> None:
        """Accumulate one batch of detections.

        Args:
            output: ``(y_pred, y)`` where ``y_pred`` is a sequence of per-image prediction
                dicts and ``y`` a sequence of per-image ground-truth dicts. Both must be
                sequences: extending a list with a bare dict would append its keys instead.
        """
        y_pred, y = output[0], output[1]
        self.val_outputs_all.extend(y_pred)
        self.val_targets_all.extend(y)

    # NOTE: ``sync_all_reduce`` is intentionally not used here: the accumulated state is a
    # list of dicts, which cannot be all-reduced. In a distributed run each process
    # presumably scores only the images it saw itself — confirm if that matters.
    def compute(self) -> Union[float, Dict]:
        """Match all accumulated predictions to targets and evaluate the COCO metric.

        Returns:
            The mean of the COCO metric values if ``reduce_scalar`` is True,
            otherwise the dict of metric values returned by the wrapped ``COCOMetric``.

        Raises:
            NotComputableError: if ``update`` has not been called since the last reset.
        """
        if not self.val_outputs_all or not self.val_targets_all:
            raise NotComputableError("IgniteCocoMetric must have at least one example before it can be computed.")

        # Convert into local variables so compute() leaves the accumulated state untouched;
        # it may be called more than once between resets.
        val_outputs = detach_to_numpy(self.val_outputs_all)
        val_targets = detach_to_numpy(self.val_targets_all)

        results_metric = matching_batch(
            iou_fn=box_utils.box_iou,
            iou_thresholds=self.coco_metric.iou_thresholds,
            pred_boxes=[val_data_i[self.box_key] for val_data_i in val_outputs],
            pred_classes=[val_data_i[self.label_key] for val_data_i in val_outputs],
            pred_scores=[val_data_i[self.pred_score_key] for val_data_i in val_outputs],
            gt_boxes=[val_data_i[self.box_key] for val_data_i in val_targets],
            gt_classes=[val_data_i[self.label_key] for val_data_i in val_targets],
        )
        # The first element of the COCOMetric result is the dict of metric values.
        val_epoch_metric_dict = self.coco_metric(results_metric)[0]
        if not self.reduce_scalar:
            return val_epoch_metric_dict
        metric_values = list(val_epoch_metric_dict.values())
        return sum(metric_values) / len(metric_values)