import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch


def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth.

    Draws one row per sample with three panels: the input image, the
    ground-truth label map and the predicted label map.

    Args:
        model: Network exposing ``eval()`` and ``predict()``; it is left in
            eval mode afterwards.
        images: Image batch. Each ``images[idx]`` is permuted ``(1, 2, 0)``
            for display, so presumably ``(N, C, H, W)`` -- confirm with caller.
        masks: Ground-truth batch. The label map is
            ``masks[idx].argmax(dim=0)``, so presumably per-class scores or
            one-hot ``(N, K, H, W)`` -- confirm with caller.
        device: Device ``images`` is moved to before ``model.predict``.
        num_samples: Maximum number of rows; clamped to ``len(images)``.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        ValueError: If there is nothing to plot (empty batch or
            ``num_samples < 1``).
    """
    # Never index past the end of the batch.
    num_samples = min(num_samples, len(images))
    if num_samples < 1:
        raise ValueError("plot_predictions needs at least one sample to plot")

    model.eval()
    with torch.no_grad():
        predictions = model.predict(images.to(device))

    # squeeze=False keeps ``axes`` 2-D even for a single row, so row/column
    # indexing works when num_samples == 1.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False
    )
    for idx in range(num_samples):
        row = axes[idx]

        # Original image, channels moved last for imshow.
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        row[0].imshow(img)
        row[0].set_title('Original Image')

        # Shared colour scale so a given class gets the same colour in the
        # ground-truth and prediction panels; without vmin/vmax imshow
        # normalises each panel to its own min/max label.
        num_classes = max(masks[idx].shape[0], predictions[idx].shape[0])
        label_range = {'vmin': 0, 'vmax': max(num_classes - 1, 0)}

        # Ground truth: collapse the class axis into a label map.
        truth = masks[idx].argmax(dim=0).cpu().numpy()
        row[1].imshow(truth, cmap='tab20', **label_range)
        row[1].set_title('Ground Truth')

        # Prediction: same reduction as the ground truth.
        pred = predictions[idx].argmax(dim=0).cpu().numpy()
        row[2].imshow(pred, cmap='tab20', **label_range)
        row[2].set_title('Prediction')

        for ax in row:
            ax.axis('off')

    fig.tight_layout()
    return fig


def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Create transparent overlay of segmentation mask on image.

    Each pixel is painted with the colour of its label from ``color_map``
    (labels missing from the map stay black). The painted layer is then
    blended with the image as ``alpha * colors + (1 - alpha) * image``
    via ``cv2.addWeighted``.

    Args:
        image: ``(H, W, 3)`` array; it is not modified.
        mask: ``(H, W)`` integer label map aligned with ``image``.
        alpha: Weight of the colour layer in the blend.
        color_map: ``{label: [r, g, b]}`` mapping, or a sequence of colours
            indexed by label (for example the list returned by
            ``generate_color_mapping``). Defaults to black/red/green/blue
            for labels 0-3. Because label 0 is black by default,
            background pixels are darkened by ``alpha``.

    Returns:
        Blended array with the same shape and dtype as ``image``.

    NOTE(review): colours are applied unscaled (0-255). A float image in
    [0, 1] will be dominated by the colour layer -- confirm the expected
    image range with callers.
    """
    if color_map is None:
        color_map = {
            0: [0, 0, 0],      # background
            1: [255, 0, 0],    # class 1 (red)
            2: [0, 255, 0],    # class 2 (green)
            3: [0, 0, 255],    # class 3 (blue)
        }
    elif not hasattr(color_map, 'items'):
        # A plain sequence palette: position in the sequence is the label.
        color_map = dict(enumerate(color_map))

    overlay = image.copy()
    # Explicitly C-contiguous, so OpenCV can consume it even when ``image``
    # is a strided view (e.g. the numpy view of a permuted tensor).
    mask_colored = np.zeros(image.shape, dtype=image.dtype)
    for label, color in color_map.items():
        mask_colored[mask == label] = color

    # Blend in place into ``overlay`` (dst argument).
    cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
    return overlay


def plot_training_history(history):
    """Plot training and validation metrics.

    Args:
        history: Mapping with per-epoch sequences under ``'train_loss'``,
            ``'val_loss'`` and ``'mean_iou'``. It may also contain
            ``'class_ious'``, a mapping of class name to a per-epoch IoU
            sequence; when that key is absent only the mean IoU is drawn.

    Returns:
        A ``Figure`` with a loss panel (left) and an IoU panel (right).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Loss plot
    ax1.plot(history['train_loss'], label='Training Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()

    # IoU plot; per-class curves are optional.
    ax2.plot(history['mean_iou'], label='Mean IoU')
    for class_name, ious in history.get('class_ious', {}).items():
        ax2.plot(ious, label=f'{class_name} IoU')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('IoU')
    ax2.set_title('IoU Metrics')
    ax2.legend()

    fig.tight_layout()
    return fig


def visualize_predictions_on_batch(model, batch_images, batch_size=8, device=None):
    """Create grid visualization for a batch of predictions.

    Shows up to ``batch_size`` images, each with its predicted label map
    blended on top by ``create_overlay_mask``, four images per row.

    Args:
        model: Network exposing ``eval()`` and ``predict()``; it is put in
            eval mode, matching ``plot_predictions``.
        batch_images: Image batch; each item is permuted ``(1, 2, 0)`` for
            display, so presumably ``(N, C, H, W)`` -- confirm with caller.
        batch_size: Maximum number of images to show.
        device: Optional device to move ``batch_images`` to before
            prediction. When ``None`` the batch is used where it already is.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    num_images = min(batch_size, len(batch_images))

    model.eval()
    if device is not None:
        batch_images = batch_images.to(device)
    with torch.no_grad():
        predictions = model.predict(batch_images)

    n_cols = 4
    # Keep the original 2x4 layout for up to 8 images, but add rows beyond
    # that instead of overflowing the grid (subplot index 9 on 2x4 raises).
    n_rows = max(2, (num_images + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(15, 2.5 * n_rows))

    for idx in range(num_images):
        ax = fig.add_subplot(n_rows, n_cols, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        overlay = create_overlay_mask(img, mask)
        ax.imshow(overlay)
        ax.axis('off')

    fig.tight_layout()
    return fig


def save_visualization(fig, save_path, dpi=300):
    """Save a figure to ``save_path`` and close it.

    The figure is closed even if saving fails, so repeated calls do not
    accumulate open figures.

    Args:
        fig: Figure to save.
        save_path: Destination passed to ``Figure.savefig``.
        dpi: Output resolution.
    """
    try:
        fig.savefig(save_path, bbox_inches='tight', dpi=dpi)
    finally:
        plt.close(fig)


def generate_color_mapping(num_classes):
    """Generate distinct colors for segmentation classes.

    The first ten colours are a fixed palette (background black, then red,
    green, blue, yellow, magenta, cyan, dark red, dark green, dark blue).
    Further colours are generated deterministically by stepping the hue by
    the golden-ratio conjugate, which spreads them around the colour wheel.

    Args:
        num_classes: Number of colours to return.

    Returns:
        A new list of ``num_classes`` ``[r, g, b]`` lists with 0-255 ints.

    Raises:
        ValueError: If ``num_classes`` is negative.
    """
    if num_classes < 0:
        raise ValueError(f"num_classes must be non-negative, got {num_classes}")

    colors = [
        [0, 0, 0],        # Background (black)
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [128, 0, 0],      # Dark Red
        [0, 128, 0],      # Dark Green
        [0, 0, 128],      # Dark Blue
    ]

    if num_classes > len(colors):
        import colorsys

        golden_ratio_conjugate = 0.618033988749895
        hue = 0.0
        while len(colors) < num_classes:
            hue = (hue + golden_ratio_conjugate) % 1.0
            r, g, b = colorsys.hsv_to_rgb(hue, 0.65, 0.95)
            colors.append([round(r * 255), round(g * 255), round(b * 255)])

    return colors[:num_classes]