import itertools
from typing import Dict, Optional

import numpy as np


def _class_histogram(labels, num_classes: int):
    """Count how many entries of ``labels`` equal each class id ``0..num_classes-1``."""
    # `num_classes` equal-width bins spanning [0, num_classes - 1] place integer
    # id k in bin k (NumPy closes the last bin on the right); values outside
    # that range are not counted.
    return np.histogram(labels, bins=num_classes, range=(0, num_classes - 1))[0]


def compute_intersection_union(
    prediction,
    ground_truth,
    num_classes: int,
    ignore_index: int,
    label_mapping: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    """Compute per-class intersection and union counts for a single sample.

    Args:
        prediction: Predicted label map (array-like of integer class ids).
        ground_truth: Reference label map with the same shape as ``prediction``.
            The caller's object is never modified.
        num_classes: Number of classes; ids outside ``[0, num_classes - 1]``
            are not counted.
        ignore_index: Ground-truth value whose positions are excluded from
            every count.
        label_mapping: Optional ``{old_id: new_id}`` remapping of
            ``ground_truth``. All entries are matched against the original
            labels, so entries do not chain into each other.
        reduce_labels: If True, ground-truth ids 0 and 255 both become 255 and
            every other id ``k`` becomes ``k - 1``.

    Returns:
        ``(area_intersection, area_union, area_prediction, area_ground_truth)``,
        each an integer array of length ``num_classes``.
    """
    prediction = np.asarray(prediction)
    # np.array always copies, so neither the remapping nor reduce_labels below
    # can touch the caller's data. Converting *before* remapping also makes
    # the mask-based assignment work for plain (nested) lists.
    ground_truth = np.array(ground_truth)

    if label_mapping:
        # Match against the untouched labels so {1: 2, 2: 3} maps 1 -> 2, not 1 -> 3.
        original = ground_truth.copy()
        for old_id, new_id in label_mapping.items():
            ground_truth[original == old_id] = new_id

    if reduce_labels:
        # Shift ids down by one; the former class 0 (and 255 itself) end up as 255.
        ground_truth[ground_truth == 0] = 255
        ground_truth = ground_truth - 1
        ground_truth[ground_truth == 254] = 255

    valid_mask = np.not_equal(ground_truth, ignore_index)
    prediction = prediction[valid_mask]
    ground_truth = ground_truth[valid_mask]

    intersection = prediction[prediction == ground_truth]

    area_intersection = _class_histogram(intersection, num_classes)
    area_prediction = _class_histogram(prediction, num_classes)
    area_ground_truth = _class_histogram(ground_truth, num_classes)
    area_union = area_prediction + area_ground_truth - area_intersection

    return area_intersection, area_union, area_prediction, area_ground_truth


def compute_total_intersection_union(
    predictions,
    ground_truths,
    num_classes: int,
    ignore_index: int,
    label_mapping: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    """Sum per-class intersection/union counts over paired samples.

    Args:
        predictions: Iterable of predicted label maps.
        ground_truths: Iterable of reference label maps, paired with
            ``predictions`` by position.
        num_classes, ignore_index, label_mapping, reduce_labels: Forwarded to
            :func:`compute_intersection_union` for every pair.

    Returns:
        ``(total_intersection, total_union, total_prediction,
        total_ground_truth)``, each a float64 array of length ``num_classes``.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` have different
            lengths (plain ``zip`` would silently drop the extra samples).
    """
    total_intersection = np.zeros((num_classes,), dtype=np.float64)
    total_union = np.zeros((num_classes,), dtype=np.float64)
    total_prediction = np.zeros((num_classes,), dtype=np.float64)
    total_ground_truth = np.zeros((num_classes,), dtype=np.float64)

    missing = object()
    for pred, gt in itertools.zip_longest(predictions, ground_truths, fillvalue=missing):
        if pred is missing or gt is missing:
            raise ValueError("predictions and ground_truths must have the same length")
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        total_intersection += intersection
        total_union += union
        total_prediction += pred_area
        total_ground_truth += gt_area

    return total_intersection, total_union, total_prediction, total_ground_truth


def _nanmean_or_nan(values):
    """``np.nanmean(values)``, or NaN without a warning when every entry is NaN."""
    if np.any(~np.isnan(values)):
        return np.nanmean(values)
    return np.float64(np.nan)


def compute_mean_iou(
    predictions,
    ground_truths,
    num_classes: int,
    ignore_index: int,
    nan_to_num: Optional[int] = None,
    label_mapping: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    """Compute mean IoU, mean per-class accuracy and overall accuracy.

    Args:
        predictions, ground_truths, num_classes, ignore_index, label_mapping,
            reduce_labels: See :func:`compute_total_intersection_union`.
        nan_to_num: If not None, every NaN in the returned metrics is
            replaced with this value (via ``np.nan_to_num``).

    Returns:
        Dict with keys ``"mean_iou"``, ``"mean_accuracy"``,
        ``"overall_accuracy"``, ``"per_category_iou"`` and
        ``"per_category_accuracy"``. Classes that never occur in either
        prediction or ground truth get NaN per-class values and are skipped
        by the means.
    """
    intersection, union, _prediction_area, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )

    # 0/0 for absent classes (or empty input) is expected and yields NaN;
    # suppress NumPy's divide/invalid warnings for these divisions.
    with np.errstate(divide="ignore", invalid="ignore"):
        overall_accuracy = intersection.sum() / ground_truth_area.sum()
        iou_per_class = intersection / union
        accuracy_per_class = intersection / ground_truth_area

    metrics = {
        "mean_iou": _nanmean_or_nan(iou_per_class),
        "mean_accuracy": _nanmean_or_nan(accuracy_per_class),
        "overall_accuracy": overall_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class,
    }

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }

    return metrics