|
import cv2 |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch |
|
|
|
def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth.

    Draws one row per sample with three panels: the input image, the
    ground-truth label map and the predicted label map. Masks and predictions
    are reduced to label maps with ``argmax`` over dim 0 of each sample.

    Args:
        model: Object exposing ``eval()`` and ``predict(tensor)``. If it has a
            truthy ``training`` attribute, ``train()`` is called afterwards so
            the caller's mode is preserved.
        images: Batch tensor indexed per sample; each sample is permuted with
            ``(1, 2, 0)`` for display (presumably ``(C, H, W)`` -- confirm).
        masks: Per-sample tensors whose dim 0 is arg-maxed (presumably
            per-class channels -- confirm against the data loader).
        device: Device ``images`` is moved to before prediction.
        num_samples: Maximum number of rows; clamped to the batch size.

    Returns:
        The matplotlib ``Figure`` holding the grid.

    Raises:
        ValueError: If there is nothing to plot (``num_samples`` < 1 or an
            empty batch).
    """
    was_training = bool(getattr(model, 'training', False))
    try:
        with torch.no_grad():
            model.eval()
            predictions = model.predict(images.to(device))
    finally:
        # Restore the caller's mode: visualizing mid-training must not leave
        # dropout / batch-norm switched to inference behaviour.
        if was_training:
            model.train()

    # Never index past the batch, even if more samples were requested.
    count = min(num_samples, len(images), len(masks))
    if count < 1:
        raise ValueError('plot_predictions needs at least one sample to plot')

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[idx, col]
    # is always valid.
    fig, axes = plt.subplots(count, 3, figsize=(12, 4 * count), squeeze=False)

    for idx in range(count):
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original Image')

        truth = masks[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 1].imshow(truth, cmap='tab20')
        axes[idx, 1].set_title('Ground Truth')

        pred = predictions[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 2].imshow(pred, cmap='tab20')
        axes[idx, 2].set_title('Prediction')

        for ax in axes[idx]:
            ax.axis('off')

    plt.tight_layout()
    return fig
|
|
|
def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Create transparent overlay of segmentation mask on image.

    Every pixel whose label appears in ``color_map`` is painted with that
    label's color on a blank canvas shaped like ``image``. The canvas is then
    blended over the image as ``alpha * colors + (1 - alpha) * image``.
    Pixels with no mapped color (and label 0 under the default map, which is
    black) therefore come out as the image darkened by ``1 - alpha``.

    Args:
        image: numpy array to draw on; it is copied, not modified. Presumably
            ``(H, W, 3)`` since colors are RGB triples -- confirm for callers.
        mask: Integer label array indexable against ``image`` (presumably
            ``(H, W)``).
        alpha: Weight of the color layer in the blend.
        color_map: Mapping ``label -> [r, g, b]`` on a 0-255 scale. Defaults
            to black/red/green/blue for labels 0-3.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    if color_map is None:
        color_map = {
            0: [0, 0, 0],
            1: [255, 0, 0],
            2: [0, 255, 0],
            3: [0, 0, 255],
        }

    # Colors are given on a 0-255 scale. A floating-point image already in
    # [0, 1] (the range matplotlib's imshow expects for floats) would
    # otherwise be blended with 255-valued colors and come out clipped to
    # white, so rescale the palette to the image's range in that case only.
    scale = 1.0
    if (np.issubdtype(image.dtype, np.floating)
            and image.size and image.max() <= 1.0):
        scale = 1.0 / 255.0

    overlay = image.copy()
    mask_colored = np.zeros_like(image)

    for label, color in color_map.items():
        if scale != 1.0:
            color = np.asarray(color, dtype=np.float64) * scale
        mask_colored[mask == label] = color

    # dst=overlay writes the blend in place into the private copy.
    cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
    return overlay
|
|
|
def plot_training_history(history):
    """Draw loss curves and IoU curves side by side in one figure.

    Args:
        history: Mapping with per-epoch sequences under ``'train_loss'``,
            ``'val_loss'`` and ``'mean_iou'``. It also needs a
            ``'class_ious'`` mapping of class name to per-epoch IoU sequence.

    Returns:
        The matplotlib ``Figure`` (left panel: losses, right panel: IoUs).
    """
    fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: training vs. validation loss.
    loss_series = (
        ('train_loss', 'Training Loss'),
        ('val_loss', 'Validation Loss'),
    )
    for key, legend_label in loss_series:
        loss_ax.plot(history[key], label=legend_label)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()

    # Right panel: mean IoU followed by one curve per class.
    iou_ax.plot(history['mean_iou'], label='Mean IoU')
    per_class = history['class_ious']
    for name in per_class:
        iou_ax.plot(per_class[name], label=f'{name} IoU')
    iou_ax.set_xlabel('Epoch')
    iou_ax.set_ylabel('IoU')
    iou_ax.set_title('IoU Metrics')
    iou_ax.legend()

    plt.tight_layout()
    return fig
|
|
|
def visualize_predictions_on_batch(model, batch_images, batch_size=8):
    """Create grid visualization for a batch of predictions.

    Each shown image gets its predicted label map (``argmax`` over dim 0)
    blended on top via ``create_overlay_mask``. Images are laid out four per
    row. At least two rows are used, which keeps the original 2x4 layout for
    up to eight images; more rows are added when ``batch_size`` exceeds eight.

    Args:
        model: Object exposing ``eval()`` and ``predict(tensor)``. If it has a
            truthy ``training`` attribute, ``train()`` is called afterwards.
        batch_images: Batch tensor; each sample is permuted with
            ``(1, 2, 0)`` for display (presumably ``(C, H, W)`` -- confirm).
            It is passed to the model as-is, without a device move.
        batch_size: Maximum number of images to draw.

    Returns:
        The matplotlib ``Figure``.
    """
    # Match plot_predictions: predict in eval mode, then restore the
    # caller's mode so a mid-training call leaves the model as it was.
    was_training = bool(getattr(model, 'training', False))
    try:
        with torch.no_grad():
            model.eval()
            predictions = model.predict(batch_images)
    finally:
        if was_training:
            model.train()

    count = min(batch_size, len(batch_images))
    cols = 4
    # Ceil-divide without importing math; never fewer than the original 2 rows.
    rows = max(2, -(-count // cols))

    # Height scales with the row count (2 rows -> the original 15x5 figure).
    fig = plt.figure(figsize=(15, 2.5 * rows))
    for idx in range(count):
        ax = fig.add_subplot(rows, cols, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        ax.imshow(create_overlay_mask(img, mask))
        ax.axis('off')

    fig.tight_layout()
    return fig
|
|
|
def save_visualization(fig, save_path, dpi=300):
    """Save a figure to disk and release it.

    Args:
        fig: matplotlib ``Figure`` to write.
        save_path: Destination path; its format follows ``savefig`` rules.
        dpi: Output resolution (default 300, as before).

    The figure is closed even if saving fails, so repeated failures (for
    example a bad path) do not accumulate open figures.
    """
    try:
        fig.savefig(save_path, bbox_inches='tight', dpi=dpi)
    finally:
        plt.close(fig)
|
|
|
def generate_color_mapping(num_classes):
    """Generate distinct RGB colors for segmentation classes.

    The first ten classes use a fixed hand-picked palette: black, the six
    saturated primaries/secondaries, then three dark primaries. Further
    classes get colors from the PASCAL VOC bit-interleaved colormap, skipping
    any color already used. Every returned color is therefore unique.

    Args:
        num_classes: Number of colors wanted. Values <= 0 yield ``[]``.

    Returns:
        A new list of ``num_classes`` ``[r, g, b]`` lists with ints in 0-255.

    Raises:
        ValueError: If more colors are requested than 24-bit RGB can represent.
    """
    colors = [
        [0, 0, 0],
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 255, 0],
        [255, 0, 255],
        [0, 255, 255],
        [128, 0, 0],
        [0, 128, 0],
        [0, 0, 128],
    ]

    # A negative count used to slice from the end; treat it as "none".
    num_classes = max(num_classes, 0)
    if num_classes <= len(colors):
        return colors[:num_classes]

    if num_classes > 1 << 24:
        raise ValueError(
            f'cannot generate {num_classes} distinct 24-bit RGB colors'
        )

    # The VOC map is a bijection from [0, 2**24) onto RGB, so walking the
    # indices and skipping duplicates is guaranteed to terminate.
    seen = {tuple(c) for c in colors}
    index = 1
    while len(colors) < num_classes:
        rgb = _voc_palette_color(index)
        index += 1
        if rgb not in seen:
            seen.add(rgb)
            colors.append(list(rgb))
    return colors


def _voc_palette_color(index):
    """Return the PASCAL VOC colormap entry for ``index`` as an (r, g, b) tuple."""
    red = green = blue = 0
    code = index
    # Spread bits 0/1/2 of each 3-bit group of the index into r/g/b,
    # starting from each channel's most significant bit.
    for shift in range(7, -1, -1):
        red |= (code & 1) << shift
        green |= ((code >> 1) & 1) << shift
        blue |= ((code >> 2) & 1) << shift
        code >>= 3
    return (red, green, blue)