lung_nodule_ct_detection / scripts /cocometric_ignite.py
katielink's picture
complete the model package
ac91715
raw
history blame
No virus
4.94 kB
from typing import Callable, Dict, Sequence, Union
import torch
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from monai.apps.detection.metrics.coco import COCOMetric
from monai.apps.detection.metrics.matching import matching_batch
from monai.data import box_utils
from .utils import detach_to_numpy
class IgniteCocoMetric(Metric):
def __init__(
self,
coco_metric_monai: Union[None, COCOMetric] = None,
box_key="box",
label_key="label",
pred_score_key="label_scores",
output_transform: Callable = lambda x: x,
device: Union[str, torch.device, None] = None,
reduce_scalar: bool = True,
):
r"""
Computes coco detection metric in Ignite.
Args:
coco_metric_monai: the coco metric in monai.
If not given, will asume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
box_key: box key in the ground truth target dict and prediction dict.
label_key: classification label key in the ground truth target dict and prediction dict.
pred_score_key: classification score key in the prediction dict.
output_transform: A callable that is used to transform the Engine’s
process_function’s output into the form expected by the metric.
device: specifies which device updates are accumulated on.
Setting the metric’s device to be the same as your update arguments ensures
the update method is non-blocking. By default, CPU.
reduce_scalar: if True, will return the average value of coc metric values;
if False, will return an dictionary of coc metric.
Examples:
To use with ``Engine`` and ``process_function``,
simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
coco = IgniteCocoMetric()
coco.attach(default_evaluator, 'coco')
preds = [
{
'box': torch.Tensor([[1,1,1,2,2,2]]),
'label':torch.Tensor([0]),
'label_scores':torch.Tensor([0.8])
}
]
target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
state = default_evaluator.run([[preds, target]])
print(state.metrics['coco'])
.. testoutput::
1.0...
.. versionadded:: 0.4.3
"""
self.box_key = box_key
self.label_key = label_key
self.pred_score_key = pred_score_key
if coco_metric_monai is None:
self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
else:
self.coco_metric = coco_metric_monai
self.reduce_scalar = reduce_scalar
if device is None:
device = torch.device("cpu")
super(IgniteCocoMetric, self).__init__(output_transform=output_transform, device=device)
@reinit__is_reduced
def reset(self) -> None:
self.val_targets_all = []
self.val_outputs_all = []
@reinit__is_reduced
def update(self, output: Sequence[Dict]) -> None:
y_pred, y = output[0], output[1]
self.val_outputs_all += y_pred
self.val_targets_all += y
@sync_all_reduce("val_targets_all", "val_outputs_all")
def compute(self) -> float:
self.val_outputs_all = detach_to_numpy(self.val_outputs_all)
self.val_targets_all = detach_to_numpy(self.val_targets_all)
results_metric = matching_batch(
iou_fn=box_utils.box_iou,
iou_thresholds=self.coco_metric.iou_thresholds,
pred_boxes=[val_data_i[self.box_key] for val_data_i in self.val_outputs_all],
pred_classes=[val_data_i[self.label_key] for val_data_i in self.val_outputs_all],
pred_scores=[val_data_i[self.pred_score_key] for val_data_i in self.val_outputs_all],
gt_boxes=[val_data_i[self.box_key] for val_data_i in self.val_targets_all],
gt_classes=[val_data_i[self.label_key] for val_data_i in self.val_targets_all],
)
val_epoch_metric_dict = self.coco_metric(results_metric)[0]
if self.reduce_scalar:
val_epoch_metric = val_epoch_metric_dict.values()
val_epoch_metric = sum(val_epoch_metric) / len(val_epoch_metric)
return val_epoch_metric
else:
return val_epoch_metric_dict