File size: 3,753 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from typing import Dict, Optional
import numpy as np

def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
                             label_mapping: Optional[Dict[int, int]] = None,
                             reduce_labels: bool = False):
    """Compute per-class intersection and union areas for one sample.

    Args:
        prediction: Predicted label map (array-like; must be indexable by the
            same boolean mask as ``ground_truth``).
        ground_truth: Ground-truth label map (array-like). The caller's object
            is never modified.
        num_classes: Number of classes. Values outside ``[0, num_classes - 1]``
            fall outside the histogram range and are not counted.
        ignore_index: Ground-truth value whose positions are excluded from
            every count.
        label_mapping: Optional ``{old_id: new_id}`` remapping of the ground
            truth. Every entry is matched against the ORIGINAL labels, so
            mappings never chain (``{1: 2, 2: 3}`` sends 1 to 2, not to 3).
        reduce_labels: If True, ground-truth label 0 becomes 255, every other
            label ``k`` becomes ``k - 1``, and 255 stays 255. The value 255 is
            hard-coded here, independently of ``ignore_index``.

    Returns:
        Tuple ``(area_intersection, area_union, area_prediction,
        area_ground_truth)``; each is a length-``num_classes`` count array.
    """
    # np.array copies its input, so the in-place edits below never reach the
    # caller's data (and list inputs are turned into maskable arrays first).
    prediction = np.array(prediction)
    ground_truth = np.array(ground_truth)

    if label_mapping:
        # Build masks from the untouched labels so one remap cannot feed another.
        remapped = ground_truth.copy()
        for old_id, new_id in label_mapping.items():
            remapped[ground_truth == old_id] = new_id
        ground_truth = remapped

    if reduce_labels:
        # Move 0 to 255 before subtracting so unsigned dtypes cannot wrap;
        # 255 - 1 == 254 is then restored to 255.
        ground_truth[ground_truth == 0] = 255
        ground_truth = ground_truth - 1
        ground_truth[ground_truth == 254] = 255

    valid_mask = np.not_equal(ground_truth, ignore_index)
    prediction = prediction[valid_mask]
    ground_truth = ground_truth[valid_mask]

    intersection = prediction[prediction == ground_truth]

    # num_classes equal-width bins over [0, num_classes - 1] place each integer
    # label k in bin k (numpy's last bin is closed on the right).
    hist_range = (0, num_classes - 1)
    area_intersection = np.histogram(intersection, bins=num_classes, range=hist_range)[0]
    area_prediction = np.histogram(prediction, bins=num_classes, range=hist_range)[0]
    area_ground_truth = np.histogram(ground_truth, bins=num_classes, range=hist_range)[0]
    area_union = area_prediction + area_ground_truth - area_intersection

    return area_intersection, area_union, area_prediction, area_ground_truth

def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
                                   label_mapping: Optional[Dict[int, int]] = None,
                                   reduce_labels: bool = False):
    """Sum per-class intersection/union areas over paired samples.

    Each ``(prediction, ground_truth)`` pair is passed to
    ``compute_intersection_union`` with the same ``num_classes``,
    ``ignore_index``, ``label_mapping`` and ``reduce_labels``, and the four
    resulting count arrays are accumulated.

    Args:
        predictions: Iterable of predicted label maps.
        ground_truths: Iterable of ground-truth label maps, paired by position
            with ``predictions``.
        num_classes: Number of classes (length of every returned array).
        ignore_index: Ground-truth value excluded from all counts.
        label_mapping: Optional ``{old_id: new_id}`` ground-truth remapping.
        reduce_labels: Forwarded to ``compute_intersection_union``.

    Returns:
        Tuple ``(total_intersection, total_union, total_prediction,
        total_ground_truth)`` of float64 arrays of length ``num_classes``.
    """
    total_intersection = np.zeros((num_classes,), dtype=np.float64)
    total_union = np.zeros((num_classes,), dtype=np.float64)
    total_prediction = np.zeros((num_classes,), dtype=np.float64)
    total_ground_truth = np.zeros((num_classes,), dtype=np.float64)

    # NOTE: zip stops at the shorter input, so surplus samples on either side
    # are silently ignored.
    for pred, gt in zip(predictions, ground_truths):
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        total_intersection += intersection
        total_union += union
        total_prediction += pred_area
        total_ground_truth += gt_area

    return total_intersection, total_union, total_prediction, total_ground_truth

def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
                    nan_to_num: Optional[int] = None,
                    label_mapping: Optional[Dict[int, int]] = None,
                    reduce_labels: bool = False):
    """Compute mean IoU, mean accuracy and overall accuracy over samples.

    Args:
        predictions: Iterable of predicted label maps.
        ground_truths: Iterable of ground-truth label maps, paired by position.
        num_classes: Number of classes.
        ignore_index: Ground-truth value excluded from all counts.
        nan_to_num: If given, every NaN in the returned metrics is replaced
            by this value. This happens after the means are computed, so the
            means still skip NaN (absent) classes.
        label_mapping: Optional ``{old_id: new_id}`` ground-truth remapping.
        reduce_labels: Forwarded to the per-sample computation.

    Returns:
        Dict with keys ``"mean_iou"``, ``"mean_accuracy"``,
        ``"overall_accuracy"``, ``"per_category_iou"`` and
        ``"per_category_accuracy"``. A class whose union (resp. ground-truth
        area) is zero has a NaN per-category IoU (resp. accuracy).
    """
    intersection, union, _, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )

    # Intersection never exceeds union or ground-truth area, so a zero
    # denominator always comes with a zero numerator: 0/0 -> NaN is the
    # intended "class absent" marker. Silence numpy's RuntimeWarning for it.
    with np.errstate(divide="ignore", invalid="ignore"):
        total_accuracy = intersection.sum() / ground_truth_area.sum()
        iou_per_class = intersection / union
        accuracy_per_class = intersection / ground_truth_area

    metrics = {
        "mean_iou": np.nanmean(iou_per_class),
        "mean_accuracy": np.nanmean(accuracy_per_class),
        "overall_accuracy": total_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class,
    }

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }

    return metrics