obichimav's picture
Upload 42 files
8e5d8c7 verified
from typing import Dict, Optional
import numpy as np
def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
                               label_mapping: Optional[Dict[int, int]] = None,
                               reduce_labels: bool = False):
    """Compute per-class intersection and union areas for a single sample.

    Args:
        prediction: Predicted label map (array-like of class indices).
        ground_truth: Ground-truth label map, same shape as ``prediction``.
            It is copied before any remapping, so the caller's object is never
            modified (plain Python lists are accepted as well).
        num_classes: Number of classes. Labels outside ``[0, num_classes - 1]``
            are not counted in any bin.
        ignore_index: Ground-truth label whose pixels are excluded from every
            count (in both prediction and ground truth).
        label_mapping: Optional ``{old_id: new_id}`` mapping applied to the
            ground truth first. All entries are applied simultaneously, based
            on the original labels, so ``{1: 2, 2: 3}`` maps 1 -> 2 (not 3).
        reduce_labels: If True, ground-truth label 0 is mapped to 255 and every
            other label is shifted down by one; 255 stays 255.

    Returns:
        Tuple ``(area_intersection, area_union, area_prediction,
        area_ground_truth)``, each an array of length ``num_classes`` holding
        per-class pixel counts.
    """
    prediction = np.array(prediction)
    # np.array copies its input, so the remapping below never touches the
    # caller's data and boolean-mask assignment works even for list inputs.
    ground_truth = np.array(ground_truth)

    if label_mapping:
        # Build every mask from the untouched labels so that chained mappings
        # (e.g. {1: 2, 2: 3}) do not cascade into each other.
        original_labels = ground_truth.copy()
        for old_id, new_id in label_mapping.items():
            ground_truth[original_labels == old_id] = new_id

    if reduce_labels:
        ground_truth[ground_truth == 0] = 255
        ground_truth = ground_truth - 1
        # Labels that were 0 or 255 now read 254; restore them to 255.
        ground_truth[ground_truth == 254] = 255

    valid_mask = np.not_equal(ground_truth, ignore_index)
    prediction = prediction[valid_mask]
    ground_truth = ground_truth[valid_mask]

    intersection = prediction[prediction == ground_truth]

    # num_classes equal-width bins over [0, num_classes - 1]: each integer
    # label k falls into bin k; values outside that range are dropped.
    hist_range = (0, num_classes - 1)
    area_intersection = np.histogram(intersection, bins=num_classes, range=hist_range)[0]
    area_prediction = np.histogram(prediction, bins=num_classes, range=hist_range)[0]
    area_ground_truth = np.histogram(ground_truth, bins=num_classes, range=hist_range)[0]
    area_union = area_prediction + area_ground_truth - area_intersection
    return area_intersection, area_union, area_prediction, area_ground_truth
def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
                                     label_mapping: Optional[Dict[int, int]] = None,
                                     reduce_labels: bool = False):
    """Accumulate per-class areas over all (prediction, ground truth) pairs.

    Each pair is passed to ``compute_intersection_union`` with the same
    ``num_classes``, ``ignore_index``, ``label_mapping`` and ``reduce_labels``,
    and the per-sample counts are summed into float64 totals.

    Args:
        predictions: Iterable of predicted label maps.
        ground_truths: Iterable of ground-truth label maps, paired with
            ``predictions`` in order.
        num_classes: Number of classes (length of every returned array).
        ignore_index: Ground-truth label excluded from all counts.
        label_mapping: Optional ``{old_id: new_id}`` ground-truth remapping.
        reduce_labels: Whether to apply the "reduce labels" shift to the
            ground truth.

    Returns:
        Tuple ``(total_intersection, total_union, total_prediction,
        total_ground_truth)`` of float64 arrays of length ``num_classes``.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` contain a
            different number of samples (plain ``zip`` would silently drop
            the extra samples and skew the metrics).
    """
    from itertools import zip_longest

    missing = object()  # sentinel marking the end of the shorter iterable
    total_intersection = np.zeros((num_classes,), dtype=np.float64)
    total_union = np.zeros((num_classes,), dtype=np.float64)
    total_prediction = np.zeros((num_classes,), dtype=np.float64)
    total_ground_truth = np.zeros((num_classes,), dtype=np.float64)

    for pred, gt in zip_longest(predictions, ground_truths, fillvalue=missing):
        if pred is missing or gt is missing:
            raise ValueError(
                "predictions and ground_truths must contain the same number of samples"
            )
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        total_intersection += intersection
        total_union += union
        total_prediction += pred_area
        total_ground_truth += gt_area

    return total_intersection, total_union, total_prediction, total_ground_truth
def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
                     nan_to_num: Optional[float] = None,
                     label_mapping: Optional[Dict[int, int]] = None,
                     reduce_labels: bool = False):
    """Compute mean IoU, mean accuracy and overall pixel accuracy.

    Per-class areas are accumulated over the whole dataset by
    ``compute_total_intersection_union`` before any ratio is taken.

    Args:
        predictions: Iterable of predicted label maps.
        ground_truths: Iterable of ground-truth label maps.
        num_classes: Number of classes.
        ignore_index: Ground-truth label excluded from all counts.
        nan_to_num: If not None, every NaN in the returned metrics is replaced
            by this value (via ``np.nan_to_num``).
        label_mapping: Optional ``{old_id: new_id}`` ground-truth remapping.
        reduce_labels: Whether to apply the "reduce labels" shift to the
            ground truth.

    Returns:
        Dict with keys ``"mean_iou"``, ``"mean_accuracy"``,
        ``"overall_accuracy"``, ``"per_category_iou"`` and
        ``"per_category_accuracy"``.
    """
    intersection, union, _, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )

    # A class with zero union / zero ground-truth area yields 0/0 = NaN on
    # purpose: np.nanmean below skips it. Silence the expected warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        total_accuracy = intersection.sum() / ground_truth_area.sum()
        iou_per_class = intersection / union
        accuracy_per_class = intersection / ground_truth_area

    metrics = {
        "mean_iou": np.nanmean(iou_per_class),
        "mean_accuracy": np.nanmean(accuracy_per_class),
        "overall_accuracy": total_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class,
    }

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }
    return metrics