import os

import cv2
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_validation_augmentations


def evaluate_model(model_config, data_path, image_size=None):
    """Evaluate a segmentation model on a labelled dataset and report IoU.

    Args:
        model_config: Project model configuration. This function reads its
            ``classes``, ``background_flag``, ``preprocessing``, ``model`` and
            ``n_classes`` attributes.
        data_path: Directory that must contain ``Images`` and ``Masks``
            subdirectories.
        image_size: Side length passed to the validation augmentations as both
            width and height. When falsy, it is inferred as the larger
            dimension of the first readable image in ``Images`` (sorted by
            file name).

    Returns:
        The metrics mapping returned by ``compute_mean_iou``. Its
        ``'mean_iou'`` and ``'per_category_iou'`` entries are printed.

    Raises:
        FileNotFoundError: If ``Images`` or ``Masks`` is missing, or if
            ``image_size`` has to be inferred and no readable image exists.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Prepend an explicit background label when the config asks for one.
    if model_config.background_flag:
        classes = ['background'] + model_config.classes
    else:
        classes = model_config.classes

    data_path = os.path.realpath(data_path)
    image_subdir = os.path.join(data_path, 'Images')
    mask_subdir = os.path.join(data_path, 'Masks')
    if not all(os.path.exists(d) for d in (image_subdir, mask_subdir)):
        # FileNotFoundError subclasses Exception, so existing handlers still match.
        raise FileNotFoundError("Missing required subdirectories: 'Images' and 'Masks'")

    if not image_size:
        image_size = _infer_image_size(image_subdir)

    evaluation_dataset = SegmentationDataset(
        data_path,
        classes=classes,
        augmentation=get_validation_augmentations(
            im_width=image_size,
            im_height=image_size,
            fixed_size=False,
        ),
        preprocessing=model_config.preprocessing,
    )
    evaluation_loader = DataLoader(
        evaluation_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )

    model = model_config.model.to(device)
    model.eval()

    # Single-class models are thresholded at 0.5. If the segmentation head does
    # not end in a Sigmoid activation, its outputs are treated as logits and
    # passed through sigmoid first.
    requires_sigmoid = False
    if model_config.n_classes == 1:
        requires_sigmoid = _check_activation_function(model) != 'Sigmoid'

    predictions = []
    ground_truth = []

    print("Evaluating model performance...")
    with torch.no_grad():
        for images, masks in tqdm(evaluation_loader):
            # Only the images go to `device`; masks are compared on the CPU,
            # so copying them to the GPU and back would be wasted work.
            outputs = model(images.to(device))

            if model_config.n_classes > 1:
                # Assumes per-sample outputs and masks are (C, H, W) with the
                # class axis first (masks presumably one-hot) -- confirm
                # against SegmentationDataset.
                predictions.extend(p.argmax(dim=0).cpu() for p in outputs)
                ground_truth.extend(gt.cpu().argmax(dim=0) for gt in masks)
            else:
                if requires_sigmoid:
                    outputs = torch.sigmoid(outputs)
                predictions.extend((p > 0.5).float().squeeze().cpu() for p in outputs)
                ground_truth.extend(gt.cpu().squeeze() for gt in masks)

    metrics = compute_mean_iou(
        predictions,
        ground_truth,
        num_labels=len(classes),
        ignore_index=255,
    )

    print("\nEvaluation Results:")
    print(f"Mean IoU: {metrics['mean_iou']:.3f}")
    print("\nPer-class IoU:")
    for class_name, iou in zip(classes, metrics['per_category_iou']):
        print(f"{class_name}: {iou:.3f}")

    return metrics


def _infer_image_size(image_dir):
    """Return max(height, width) of the first OpenCV-readable file in *image_dir*.

    Entries are visited in sorted order so the choice is deterministic.
    Subdirectories and files that ``cv2.imread`` cannot decode (it returns
    ``None`` for those) are skipped.

    Raises:
        FileNotFoundError: If no readable image is found.
    """
    for entry in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, entry)
        if not os.path.isfile(path):
            continue
        image = cv2.imread(path)
        if image is not None:
            height, width = image.shape[:2]
            return max(height, width)
    raise FileNotFoundError(f"No readable images found in '{image_dir}'")


def _check_activation_function(model):
    """Return the activation class name used in the model's segmentation head.

    Scans the direct children of ``model.segmentation_head`` for
    ``Activation`` wrappers. Returns the class name of the module wrapped by
    the last one (e.g. ``'Sigmoid'``), or ``None`` if no ``Activation`` child
    is present.
    """
    activation_name = None
    for module in model.segmentation_head.children():
        if isinstance(module, Activation):
            activation_name = type(module.activation).__name__
    return activation_name