import os
import cv2
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from segmentation_models_pytorch.base.modules import Activation
from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_validation_augmentations
def evaluate_model(model_config, data_path, image_size=None):
"""Evaluates model performance on a dataset."""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
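    # Prepend an explicit background class when the model config flags one.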
    classes = (['background'] + model_config.classes) if model_config.background_flag else model_config.classes
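    # The dataset root is expected to contain 'Images' and 'Masks' subdirectories.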
data_path = os.path.realpath(data_path)
image_subdir = os.path.join(data_path, 'Images')
mask_subdir = os.path.join(data_path, 'Masks')
if not all(os.path.exists(d) for d in [image_subdir, mask_subdir]):
raise Exception("Missing required subdirectories: 'Images' and 'Masks'")
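    # If no evaluation size is given, infer it from the first image found in
    # the Images directory (assumes that entry is a readable image file).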
if not image_size:
sample_image = cv2.imread(os.path.join(image_subdir, os.listdir(image_subdir)[0]))
height, width = sample_image.shape[:2]
image_size = max(height, width)
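    # Wrap the data in the evaluation dataset, using validation-time
    # augmentations sized to image_size and the model's preprocessing.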
evaluation_dataset = SegmentationDataset(
data_path,
classes=classes,
augmentation=get_validation_augmentations(
im_width=image_size,
im_height=image_size,
fixed_size=False
),
preprocessing=model_config.preprocessing
)
evaluation_loader = DataLoader(
evaluation_dataset,
batch_size=1,
shuffle=False,
num_workers=2
)
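    # Move the model to the evaluation device and switch to inference mode.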
model = model_config.model.to(device)
model.eval()
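    # Binary (single-class) models whose segmentation head lacks a sigmoid
    # activation need sigmoid applied to their raw outputs during inference.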
requires_sigmoid = False
if model_config.n_classes == 1:
current_activation = _check_activation_function(model)
if current_activation != 'Sigmoid':
requires_sigmoid = True
predictions = []
ground_truth = []
print("Evaluating model performance...")
with torch.no_grad():
for images, masks in tqdm(evaluation_loader):
images = images.to(device)
masks = masks.to(device)
            outputs = model(images)
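            # Multi-class models: reduce each output to a label map via argmax over
            # the channel dimension. Binary models: threshold probabilities at 0.5,
            # applying sigmoid first if the model emits raw logits.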
if model_config.n_classes > 1:
predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
ground_truth.extend([gt.cpu().argmax(dim=0) for gt in masks])
else:
if requires_sigmoid:
predictions.extend([
(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
for p in outputs
])
else:
predictions.extend([
(p > 0.5).float().squeeze().cpu()
for p in outputs
])
ground_truth.extend([gt.cpu().squeeze() for gt in masks])
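    # Aggregate IoU over the whole dataset; pixels labeled 255 are ignored.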
metrics = compute_mean_iou(
predictions,
ground_truth,
num_labels=len(classes),
ignore_index=255
)
print("\nEvaluation Results:")
print(f"Mean IoU: {metrics['mean_iou']:.3f}")
print("\nPer-class IoU:")
for idx, iou in enumerate(metrics['per_category_iou']):
print(f"{classes[idx]}: {iou:.3f}")
return metrics
def _check_activation_function(model):
"""Checks the activation function used in model's segmentation head."""
activation_functions = []
for _, module in model.segmentation_head.named_children():
if isinstance(module, Activation):
activation_functions.append(type(module.activation).__name__)
    return activation_functions[-1] if activation_functions else None
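

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the original module).
# The config object below is an assumption: SimpleNamespace merely mirrors the
# attributes evaluate_model() reads (.model, .classes, .background_flag,
# .n_classes, .preprocessing); the project presumably has its own config class.
# The dataset path is hypothetical and must follow the Images/Masks layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace
    import segmentation_models_pytorch as smp

    example_config = SimpleNamespace(
        model=smp.Unet(encoder_name='resnet34', encoder_weights=None,
                       classes=1, activation='sigmoid'),
        classes=['foreground'],
        background_flag=True,
        n_classes=1,
        preprocessing=None,
    )
    evaluate_model(example_config, '/path/to/dataset', image_size=512)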