|
import itertools
import warnings
from typing import Dict, Optional

import numpy as np
|
|
|
def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
                               label_mapping: Optional[Dict[int, int]] = None,
                               reduce_labels: bool = False):
    """Compute per-class intersection and union areas for one sample.

    Args:
        prediction: Predicted class ids for one sample (anything ``np.array``
            accepts). Must have the same shape as ``ground_truth``.
        ground_truth: Ground-truth class ids for the same sample. The caller's
            object is never modified.
        num_classes: Number of classes. Only ids in ``[0, num_classes - 1]``
            are counted; other values fall outside the histogram range.
        ignore_index: Ground-truth value (checked after remapping) whose
            positions are excluded from every count.
        label_mapping: Optional ``{old_id: new_id}`` remapping of the ground
            truth. Every key is matched against the original values, so
            entries do not chain (``{1: 2, 2: 3}`` sends 1 to 2, not to 3).
        reduce_labels: If True, ground-truth 0 is set to 255, all values are
            decremented by one, and any resulting 254 is set back to 255.

    Returns:
        Tuple ``(area_intersection, area_union, area_prediction,
        area_ground_truth)``, each an integer array of length ``num_classes``.
    """
    # Convert first. np.array copies its input, so the in-place edits below
    # never reach the caller's object. Plain (nested) lists also gain
    # element-wise comparison and boolean masking.
    prediction = np.array(prediction)
    ground_truth = np.array(ground_truth)

    if label_mapping:
        # Match against a snapshot so an id written by one entry is not
        # re-mapped by a later entry.
        original_ground_truth = ground_truth.copy()
        for old_id, new_id in label_mapping.items():
            ground_truth[original_ground_truth == old_id] = new_id

    if reduce_labels:
        ground_truth[ground_truth == 0] = 255
        ground_truth = ground_truth - 1
        ground_truth[ground_truth == 254] = 255

    valid_mask = np.not_equal(ground_truth, ignore_index)
    prediction = prediction[valid_mask]
    ground_truth = ground_truth[valid_mask]

    intersection = prediction[prediction == ground_truth]

    def class_areas(values):
        # bins=num_classes over [0, num_classes - 1]: integer id k lands in
        # bin k (numpy closes the last bin on the right), and values outside
        # that range are not counted.
        return np.histogram(values, bins=num_classes, range=(0, num_classes - 1))[0]

    area_intersection = class_areas(intersection)
    area_prediction = class_areas(prediction)
    area_ground_truth = class_areas(ground_truth)
    area_union = area_prediction + area_ground_truth - area_intersection

    return area_intersection, area_union, area_prediction, area_ground_truth
|
|
|
def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
                                     label_mapping: Optional[Dict[int, int]] = None,
                                     reduce_labels: bool = False):
    """Accumulate per-class areas over paired predictions and ground truths.

    Each ``(prediction, ground_truth)`` pair is passed to
    ``compute_intersection_union`` with the same ``num_classes``,
    ``ignore_index``, ``label_mapping`` and ``reduce_labels``. The four
    per-class area arrays it returns are summed across all pairs.

    Args:
        predictions: Iterable of per-sample predicted class-id maps.
        ground_truths: Iterable of per-sample ground-truth maps, paired
            positionally with ``predictions``.
        num_classes: Number of classes; length of every returned array.
        ignore_index: Forwarded to ``compute_intersection_union``.
        label_mapping: Forwarded to ``compute_intersection_union``.
        reduce_labels: Forwarded to ``compute_intersection_union``.

    Returns:
        Tuple ``(total_intersection, total_union, total_prediction,
        total_ground_truth)`` of float64 arrays of length ``num_classes``.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` yield a different
            number of items.
    """
    total_intersection = np.zeros((num_classes,), dtype=np.float64)
    total_union = np.zeros((num_classes,), dtype=np.float64)
    total_prediction = np.zeros((num_classes,), dtype=np.float64)
    total_ground_truth = np.zeros((num_classes,), dtype=np.float64)

    # A plain zip() would stop at the shorter input and silently report
    # metrics on a subset of the data. zip_longest with a private sentinel
    # lets the mismatch be detected instead.
    missing = object()
    for pred, gt in itertools.zip_longest(predictions, ground_truths, fillvalue=missing):
        if pred is missing or gt is missing:
            raise ValueError(
                "predictions and ground_truths must contain the same number of items"
            )
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        total_intersection += intersection
        total_union += union
        total_prediction += pred_area
        total_ground_truth += gt_area

    return total_intersection, total_union, total_prediction, total_ground_truth
|
|
|
def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
                     nan_to_num: Optional[int] = None,
                     label_mapping: Optional[Dict[int, int]] = None,
                     reduce_labels: bool = False):
    """Compute mean IoU, mean accuracy and per-class metrics over all samples.

    Args:
        predictions: Iterable of per-sample predicted class-id maps.
        ground_truths: Iterable of per-sample ground-truth maps, paired
            positionally with ``predictions``.
        num_classes: Number of classes.
        ignore_index: Ground-truth value excluded from all counts.
        nan_to_num: If not None, every NaN in every returned metric is
            replaced by this value via ``np.nan_to_num``.
        label_mapping: Optional ``{old_id: new_id}`` ground-truth remapping.
        reduce_labels: Whether to apply the reduce-labels transform to the
            ground truth.

    Returns:
        Dict with scalar entries ``"mean_iou"``, ``"mean_accuracy"`` and
        ``"overall_accuracy"``, and per-class arrays ``"per_category_iou"``
        and ``"per_category_accuracy"``. A class with zero union has NaN IoU;
        a class with zero ground-truth area has NaN accuracy. Both means skip
        NaN entries.
    """
    intersection, union, _, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )

    # 0/0 for classes that never occur is expected. The resulting NaN marks
    # "undefined" and is skipped by the means, so numpy's divide/invalid
    # warnings are just noise here.
    with np.errstate(divide="ignore", invalid="ignore"):
        total_accuracy = intersection.sum() / ground_truth_area.sum()
        iou_per_class = intersection / union
        accuracy_per_class = intersection / ground_truth_area

    with warnings.catch_warnings():
        # np.nanmean warns "Mean of empty slice" when every entry is NaN.
        # A NaN mean is the intended answer in that case.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_iou = np.nanmean(iou_per_class)
        mean_accuracy = np.nanmean(accuracy_per_class)

    metrics = {
        "mean_iou": mean_iou,
        "mean_accuracy": mean_accuracy,
        "overall_accuracy": total_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class,
    }

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }

    return metrics