File size: 3,722 Bytes
8e5d8c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth.

    Draws one row per sample with three panels: the input image, the
    ground-truth label map and the predicted label map.

    Args:
        model: Object exposing ``eval()`` and ``predict(batch)``.
        images: Indexable batch of image tensors. Each sample is permuted
            with ``(1, 2, 0)`` before display, so it is assumed to be
            channel-first (C, H, W) -- TODO confirm against the caller.
        masks: Indexable batch of mask tensors. Each sample is reduced with
            ``argmax(dim=0)``, so it is presumably a per-class score or
            one-hot map with the class axis first -- TODO confirm.
        device: Device the images are moved to before prediction.
        num_samples: Maximum number of rows to draw. Clamped to the number
            of images available.

    Returns:
        The matplotlib figure holding the grid.

    Raises:
        ValueError: If there is no sample to draw.
    """
    model.eval()
    with torch.no_grad():
        predictions = model.predict(images.to(device))

    # Never index past the end of the batch when fewer images are supplied
    # than rows were requested.
    num_rows = min(num_samples, len(images))
    if num_rows < 1:
        raise ValueError('plot_predictions needs at least one sample to draw')

    # squeeze=False keeps ``axes`` two-dimensional even for a single row;
    # otherwise ``axes[idx, 0]`` fails when only one sample is plotted.
    fig, axes = plt.subplots(
        num_rows, 3, figsize=(12, 4 * num_rows), squeeze=False
    )

    for idx in range(num_rows):
        image_ax, truth_ax, pred_ax = axes[idx]

        # Original image, moved to channel-last layout for imshow.
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        image_ax.imshow(img)
        image_ax.set_title('Original Image')

        # Ground truth: collapse the leading axis into a label map.
        truth = masks[idx].argmax(dim=0).cpu().numpy()
        truth_ax.imshow(truth, cmap='tab20')
        truth_ax.set_title('Ground Truth')

        # Prediction: same reduction as the ground truth.
        pred = predictions[idx].argmax(dim=0).cpu().numpy()
        pred_ax.imshow(pred, cmap='tab20')
        pred_ax.set_title('Prediction')

        for ax in axes[idx]:
            ax.axis('off')

    fig.tight_layout()
    return fig
def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Create transparent overlay of segmentation mask on image.

    A color layer with the same shape and dtype as ``image`` is painted from
    ``mask`` and then blended with a copy of the image using
    ``cv2.addWeighted``. The blend covers the whole image: pixels whose
    label has no entry in ``color_map`` (or maps to black) stay black in
    the color layer and are therefore blended toward black.

    Args:
        image: NumPy image array. It is copied, never modified.
        mask: Array of class labels used to index ``image``-shaped pixels
            via ``mask == label``; presumably shaped (H, W) to match the
            image -- TODO confirm against the caller.
        alpha: Weight of the color layer; the image gets ``1 - alpha``.
        color_map: Either a mapping ``{label: [c0, c1, c2]}`` or a plain
            sequence of colors where the position is the label (the format
            returned by ``generate_color_mapping``). Defaults to a built-in
            four-class map.

    Returns:
        A new array holding the blended image.
    """
    if color_map is None:
        color_map = {
            0: [0, 0, 0],      # background
            1: [255, 0, 0],    # class 1 (red)
            2: [0, 255, 0],    # class 2 (green)
            3: [0, 0, 255],    # class 3 (blue)
        }
    elif not hasattr(color_map, 'items'):
        # A list/tuple of colors has no ``items()``; treat its index as the
        # class label so sequence-style palettes work as well as dicts.
        color_map = dict(enumerate(color_map))

    overlay = image.copy()
    mask_colored = np.zeros_like(image)
    for label, color in color_map.items():
        mask_colored[mask == label] = color

    # Blend in place into the copy so the caller's image stays untouched.
    cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
    return overlay
def plot_training_history(history):
    """Plot training and validation metrics.

    Builds a two-panel figure: loss curves on the left and IoU curves on
    the right.

    Args:
        history: Mapping read with the keys ``'train_loss'``,
            ``'val_loss'``, ``'mean_iou'`` and ``'class_ious'``. The last
            one must itself provide ``items()`` yielding
            ``(class_name, values)`` pairs.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: training loss first, then validation loss.
    loss_curves = (
        ('train_loss', 'Training Loss'),
        ('val_loss', 'Validation Loss'),
    )
    for history_key, legend_text in loss_curves:
        loss_ax.plot(history[history_key], label=legend_text)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()

    # Right panel: the mean curve followed by one curve per class.
    iou_ax.plot(history['mean_iou'], label='Mean IoU')
    per_class_ious = history['class_ious']
    for class_name, iou_values in per_class_ious.items():
        iou_ax.plot(iou_values, label=f'{class_name} IoU')
    iou_ax.set_xlabel('Epoch')
    iou_ax.set_ylabel('IoU')
    iou_ax.set_title('IoU Metrics')
    iou_ax.legend()

    plt.tight_layout()
    return fig
def visualize_predictions_on_batch(model, batch_images, batch_size=8):
    """Create grid visualization for a batch of predictions.

    Runs ``model.predict`` on the batch and draws each image with its
    predicted label map blended on top via ``create_overlay_mask``.

    Args:
        model: Object exposing ``predict(batch)``.
        batch_images: Indexable batch of image tensors. Each sample is
            permuted with ``(1, 2, 0)`` before display, so it is assumed to
            be channel-first (C, H, W) -- TODO confirm.
        batch_size: Maximum number of images to draw.

    Returns:
        The matplotlib figure holding the grid. The grid is four columns
        wide; it has two rows for up to eight images and grows by one row
        per additional four images.
    """
    with torch.no_grad():
        predictions = model.predict(batch_images)

    num_shown = min(batch_size, len(batch_images))
    num_cols = 4
    # Keep the 2x4 layout for up to eight images, but add rows beyond that:
    # a fixed 2x4 grid rejects subplot index 9 and above.
    num_rows = max(2, -(-num_shown // num_cols))

    fig = plt.figure(figsize=(15, 2.5 * num_rows))
    for idx in range(num_shown):
        ax = fig.add_subplot(num_rows, num_cols, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        # NOTE(review): the default overlay colors are on a 0-255 scale; if
        # batch_images holds floats in [0, 1] the blend will be dominated
        # by the colors -- confirm the expected image value range.
        overlay = create_overlay_mask(img, mask)
        ax.imshow(overlay)
        ax.axis('off')

    fig.tight_layout()
    return fig
def save_visualization(fig, save_path):
    """Save visualization figure.

    Writes ``fig`` to ``save_path`` with a tight bounding box at 300 dpi
    and then closes the figure through pyplot.

    Args:
        fig: Figure to write; it is closed after saving.
        save_path: Destination passed straight to ``fig.savefig``.
    """
    export_options = {'bbox_inches': 'tight', 'dpi': 300}
    fig.savefig(save_path, **export_options)
    # Close the figure once it has been written out.
    plt.close(fig)
def generate_color_mapping(num_classes):
    """Generate distinct colors for segmentation classes.

    The first ten colors come from a fixed palette (background is black).
    When more classes are requested, additional colors are generated
    deterministically by stepping the hue with the golden-ratio conjugate
    at a fixed saturation and value, so the same ``num_classes`` always
    yields the same list.

    Args:
        num_classes: Number of colors to return; must be non-negative.

    Returns:
        A new list of ``num_classes`` colors, each a three-element list of
        integers in the range 0-255. The position in the list is the class
        index.

    Raises:
        ValueError: If ``num_classes`` is negative.
    """
    import colorsys  # local import: only used when extending the palette

    if num_classes < 0:
        # A negative count would otherwise slice from the end of the
        # palette and silently return a meaningless subset.
        raise ValueError(
            f'num_classes must be non-negative, got {num_classes}'
        )

    colors = [
        [0, 0, 0],        # Background (black)
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [128, 0, 0],      # Dark Red
        [0, 128, 0],      # Dark Green
        [0, 0, 128],      # Dark Blue
    ]
    if num_classes <= len(colors):
        return colors[:num_classes]

    # Saturation/value differ from every fixed palette entry (those are
    # fully saturated or black), so generated colors cannot repeat them.
    golden_ratio_conjugate = 0.618033988749895
    saturation, value = 0.65, 0.9
    for extra_idx in range(num_classes - len(colors)):
        hue = (extra_idx * golden_ratio_conjugate) % 1.0
        red, green, blue = colorsys.hsv_to_rgb(hue, saturation, value)
        colors.append(
            [round(red * 255), round(green * 255), round(blue * 255)]
        )
    return colors