Robert001's picture
first commit
b334e29
raw
history blame
No virus
13.1 kB
from collections import OrderedDict
import annotator.uniformer.mmcv as mmcv
import numpy as np
import torch
def f_score(precision, recall, beta=1):
"""calcuate the f-score value.
Args:
precision (float | torch.Tensor): The precision value.
recall (float | torch.Tensor): The recall value.
beta (int): Determines the weight of recall in the combined score.
Default: False.
Returns:
[torch.tensor]: The f-score value.
"""
score = (1 + beta**2) * (precision * recall) / (
(beta**2 * precision) + recall)
return score
def intersect_and_union(pred_label,
label,
num_classes,
ignore_index,
label_map=dict(),
reduce_zero_label=False):
"""Calculate intersection and Union.
Args:
pred_label (ndarray | str): Prediction segmentation map
or predict result filename.
label (ndarray | str): Ground truth segmentation map
or label filename.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. The parameter will
work only when label is str. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. The parameter will
work only when label is str. Default: False.
Returns:
torch.Tensor: The intersection of prediction and ground truth
histogram on all classes.
torch.Tensor: The union of prediction and ground truth histogram on
all classes.
torch.Tensor: The prediction histogram on all classes.
torch.Tensor: The ground truth histogram on all classes.
"""
if isinstance(pred_label, str):
pred_label = torch.from_numpy(np.load(pred_label))
else:
pred_label = torch.from_numpy((pred_label))
if isinstance(label, str):
label = torch.from_numpy(
mmcv.imread(label, flag='unchanged', backend='pillow'))
else:
label = torch.from_numpy(label)
if label_map is not None:
for old_id, new_id in label_map.items():
label[label == old_id] = new_id
if reduce_zero_label:
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_label = torch.histc(
label.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label
def total_intersect_and_union(results,
gt_seg_maps,
num_classes,
ignore_index,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Total Intersection and Union.
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. Default: False.
Returns:
ndarray: The intersection of prediction and ground truth histogram
on all classes.
ndarray: The union of prediction and ground truth histogram on all
classes.
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""
num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
for i in range(num_imgs):
area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(
results[i], gt_seg_maps[i], num_classes, ignore_index,
label_map, reduce_zero_label)
total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
total_area_label += area_label
return total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label
def mean_iou(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Mean Intersection and Union (mIoU)
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. Default: False.
Returns:
dict[str, float | ndarray]:
<aAcc> float: Overall accuracy on all images.
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
<IoU> ndarray: Per category IoU, shape (num_classes, ).
"""
iou_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mIoU'],
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return iou_result
def mean_dice(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Mean Dice (mDice)
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. Default: False.
Returns:
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
<Dice> ndarray: Per category dice, shape (num_classes, ).
"""
dice_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mDice'],
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return dice_result
def mean_fscore(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False,
beta=1):
"""Calculate Mean Intersection and Union (mIoU)
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. Default: False.
beta (int): Determines the weight of recall in the combined score.
Default: False.
Returns:
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<Fscore> ndarray: Per category recall, shape (num_classes, ).
<Precision> ndarray: Per category precision, shape (num_classes, ).
<Recall> ndarray: Per category f-score, shape (num_classes, ).
"""
fscore_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mFscore'],
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label,
beta=beta)
return fscore_result
def eval_metrics(results,
gt_seg_maps,
num_classes,
ignore_index,
metrics=['mIoU'],
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False,
beta=1):
"""Calculate evaluation metrics
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))
total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(
results, gt_seg_maps, num_classes, ignore_index, label_map,
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor(
[f_score(x[0], x[1], beta) for x in zip(precision, recall)])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
ret_metrics['Recall'] = recall
ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics