File size: 3,786 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import cv2
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_validation_augmentations

def _infer_image_size(image_subdir):
    """Return the longer side, in pixels, of the first readable image in a directory.

    Entries are visited in sorted order so the result does not depend on the
    arbitrary ordering of ``os.listdir``. Entries that OpenCV cannot decode
    (``cv2.imread`` returns ``None``, e.g. hidden files or sub-directories)
    are skipped instead of crashing on ``None.shape``.

    Args:
        image_subdir: Path of the directory to scan.

    Returns:
        ``max(height, width)`` of the first entry that could be decoded.

    Raises:
        FileNotFoundError: If no entry in ``image_subdir`` can be read as an image.
    """
    for entry in sorted(os.listdir(image_subdir)):
        sample_image = cv2.imread(os.path.join(image_subdir, entry))
        if sample_image is not None:
            height, width = sample_image.shape[:2]
            return max(height, width)
    raise FileNotFoundError(f"No readable images found in '{image_subdir}'")


def evaluate_model(model_config, data_path, image_size=None):
    """Evaluates model performance on a dataset.

    Runs the model over every sample found under ``data_path``, collects the
    per-sample predictions and ground-truth masks, computes IoU with
    ``compute_mean_iou`` and prints the mean and per-class values.

    Args:
        model_config: Configuration object. This function reads its
            ``classes``, ``background_flag``, ``preprocessing``, ``model``
            and ``n_classes`` attributes.
        data_path: Dataset root; it must contain ``Images`` and ``Masks``
            sub-directories.
        image_size: Value passed as both ``im_width`` and ``im_height`` to the
            validation augmentations. When falsy, it is inferred as the longer
            side of the first readable image in ``Images``.

    Returns:
        The object returned by ``compute_mean_iou``; its ``'mean_iou'`` and
        ``'per_category_iou'`` entries are the ones printed here.

    Raises:
        FileNotFoundError: If ``Images`` or ``Masks`` is missing, or if the
            image size must be inferred and no readable image exists.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Parenthesised to make the precedence explicit: 'background' is only
    # prepended when the config asks for it.
    classes = (
        ['background'] + model_config.classes
        if model_config.background_flag
        else model_config.classes
    )

    data_path = os.path.realpath(data_path)
    image_subdir = os.path.join(data_path, 'Images')
    mask_subdir = os.path.join(data_path, 'Masks')

    # isdir (not exists) so a stray regular file named 'Images' is rejected too.
    if not all(os.path.isdir(d) for d in (image_subdir, mask_subdir)):
        raise FileNotFoundError("Missing required subdirectories: 'Images' and 'Masks'")

    if not image_size:
        image_size = _infer_image_size(image_subdir)

    evaluation_dataset = SegmentationDataset(
        data_path,
        classes=classes,
        augmentation=get_validation_augmentations(
            im_width=image_size,
            im_height=image_size,
            fixed_size=False,
        ),
        preprocessing=model_config.preprocessing,
    )

    evaluation_loader = DataLoader(
        evaluation_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )

    model = model_config.model.to(device)
    model.eval()

    multiclass = model_config.n_classes > 1
    # For single-output models the raw output is thresholded at 0.5, so it
    # must be passed through a sigmoid first unless the segmentation head
    # already ends in one. The head is only inspected in that case.
    requires_sigmoid = (
        model_config.n_classes == 1
        and _check_activation_function(model) != 'Sigmoid'
    )

    predictions = []
    ground_truth = []

    print("Evaluating model performance...")
    with torch.no_grad():
        for images, masks in tqdm(evaluation_loader):
            # Call the module itself rather than .forward() so that any
            # registered forward hooks are honoured.
            outputs = model(images.to(device))

            # Masks are only ever compared on the CPU, so they are not moved
            # to the device; .cpu() is a no-op for tensors already there.
            if multiclass:
                # Collapse dim 0 of each sample to a label map.
                # NOTE(review): assumes masks are one-hot along dim 0 per
                # sample, mirroring the model output - confirm against
                # SegmentationDataset.
                predictions.extend(p.cpu().argmax(dim=0) for p in outputs)
                ground_truth.extend(gt.cpu().argmax(dim=0) for gt in masks)
            else:
                scores = torch.sigmoid(outputs) if requires_sigmoid else outputs
                predictions.extend(
                    (p > 0.5).float().squeeze().cpu() for p in scores
                )
                ground_truth.extend(gt.cpu().squeeze() for gt in masks)

    metrics = compute_mean_iou(
        predictions,
        ground_truth,
        num_labels=len(classes),
        ignore_index=255,
    )

    print("\nEvaluation Results:")
    print(f"Mean IoU: {metrics['mean_iou']:.3f}")
    print("\nPer-class IoU:")
    for class_name, iou in zip(classes, metrics['per_category_iou']):
        print(f"{class_name}: {iou:.3f}")

    return metrics

def _check_activation_function(model):
    """Report which activation the model's segmentation head ends with.

    Scans the direct children of ``model.segmentation_head`` for
    ``Activation`` wrapper modules and returns the class name of the
    ``activation`` attribute held by the last one found.

    Args:
        model: Object exposing a ``segmentation_head`` whose
            ``named_children()`` yields ``(name, module)`` pairs.

    Returns:
        The class name (a ``str``) of the last wrapped activation, or
        ``None`` when the head has no ``Activation`` child.
    """
    from segmentation_models_pytorch.base.modules import Activation

    wrapped_names = [
        type(child.activation).__name__
        for _name, child in model.segmentation_head.named_children()
        if isinstance(child, Activation)
    ]
    if not wrapped_names:
        return None
    return wrapped_names[-1]