File size: 3,722 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth.

    Draws one row per sample with three panels: the input image, the
    ground-truth label map and the predicted label map.

    Args:
        model: segmentation model exposing a ``predict`` method. It is put in
            eval mode by this call and is not switched back.
        images: batch tensor; each ``images[idx]`` is permuted ``(1, 2, 0)``
            for display, so presumably ``(N, C, H, W)`` — TODO confirm.
        masks: ground-truth batch. Per-sample entries with 3 dims are treated
            as per-class maps and reduced with ``argmax`` over dim 0; 2-D
            entries are used directly as label maps.
        device: device the images are moved to before ``model.predict``.
        num_samples: maximum number of rows to draw; clamped to the batch size.

    Returns:
        The matplotlib ``Figure`` holding the grid.
    """
    with torch.no_grad():
        model.eval()
        predictions = model.predict(images.to(device))

    # Never ask for more rows than the batch holds (previously IndexError).
    num_samples = min(num_samples, len(images))
    # squeeze=False keeps ``axes`` 2-D even for a single row, so the
    # ``axes[idx, col]`` indexing below always works.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False
    )

    for idx in range(num_samples):
        # Original image
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel inputs in grayscale.
            axes[idx, 0].imshow(img[..., 0], cmap='gray')
        else:
            axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original Image')

        # Ground truth
        axes[idx, 1].imshow(_as_label_map(masks[idx]), cmap='tab20')
        axes[idx, 1].set_title('Ground Truth')

        # Prediction
        # NOTE(review): a single-channel (binary) output would argmax to all
        # zeros; such models likely need thresholding instead — confirm.
        axes[idx, 2].imshow(_as_label_map(predictions[idx]), cmap='tab20')
        axes[idx, 2].set_title('Prediction')

        for ax in axes[idx]:
            ax.axis('off')

    plt.tight_layout()
    return fig


def _as_label_map(tensor):
    """Turn one per-sample mask/prediction tensor into a 2-D NumPy label map."""
    # 3-D inputs are presumably (num_classes, H, W) scores or one-hot maps:
    # keep the highest-scoring class per pixel. 2-D inputs are already labels.
    if tensor.dim() == 3:
        tensor = tensor.argmax(dim=0)
    return tensor.cpu().numpy()

def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Create transparent overlay of segmentation mask on image.

    Every pixel whose label appears in ``color_map`` is painted with that
    label's colour. The colour layer is then blended with the image as
    ``alpha * colours + (1 - alpha) * image``. Pixels whose label is absent
    from the map (and label 0 with the default black entry) blend with black,
    so they are darkened by a factor of ``1 - alpha``.

    Args:
        image: ``(H, W, 3)`` NumPy image (shape assumed — TODO confirm).
            Integer images are used on their native scale. Floating images
            whose maximum is <= 1.0 are treated as normalised to [0, 1], and
            the palette is rescaled to match.
        mask: ``(H, W)`` array of integer class labels aligned with ``image``.
        alpha: weight of the colour layer; the image gets ``1 - alpha``.
        color_map: mapping ``label -> [R, G, B]`` on a 0-255 scale. Defaults
            to a black background plus red/green/blue for classes 1-3.

    Returns:
        A new array with the same shape and dtype as ``image``. The input
        image is not modified.
    """
    if color_map is None:
        color_map = {
            0: [0, 0, 0],      # background
            1: [255, 0, 0],    # class 1 (red)
            2: [0, 255, 0],    # class 2 (green)
            3: [0, 0, 255],    # class 3 (blue)
        }

    # The palette is on a 0-255 scale. Float images converted from tensors are
    # typically in [0, 1]. Painting 255-valued colours into them pushes the
    # blend far outside the displayable range, so rescale the palette first.
    # NOTE(review): mean/std-normalised images can still fall outside [0, 1].
    is_unit_float = (
        np.issubdtype(image.dtype, np.floating)
        and image.size > 0
        and float(image.max()) <= 1.0
    )

    # ndarray.copy() is C-contiguous. Build the colour layer from that copy
    # (not from a possibly permuted/strided ``image``) so both cv2 operands
    # share one layout.
    overlay = image.copy()
    mask_colored = np.zeros_like(overlay)

    for label, color in color_map.items():
        if is_unit_float:
            color = [channel / 255.0 for channel in color]
        mask_colored[mask == label] = color

    # In-place blend into ``overlay`` (dst argument), gamma = 0.
    cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
    return overlay

def plot_training_history(history):
    """Plot training and validation metrics.

    Builds a 1x2 figure: loss curves on the left, IoU curves on the right.

    Args:
        history: dict read for the keys ``'train_loss'``, ``'val_loss'``,
            ``'mean_iou'`` (per-epoch sequences) and ``'class_ious'`` (a
            mapping of class name -> per-epoch sequence). A missing key
            raises ``KeyError``.

    Returns:
        The matplotlib ``Figure``.
    """
    fig, panels = plt.subplots(1, 2, figsize=(12, 4))
    loss_panel, iou_panel = panels

    # Left panel: loss curves, drawn train-then-validation.
    loss_series = (
        ('train_loss', 'Training Loss'),
        ('val_loss', 'Validation Loss'),
    )
    for key, caption in loss_series:
        loss_panel.plot(history[key], label=caption)
    loss_panel.set_xlabel('Epoch')
    loss_panel.set_ylabel('Loss')
    loss_panel.set_title('Training and Validation Loss')
    loss_panel.legend()

    # Right panel: the mean IoU first, then one curve per class.
    iou_panel.plot(history['mean_iou'], label='Mean IoU')
    for name, per_epoch in history['class_ious'].items():
        iou_panel.plot(per_epoch, label=f'{name} IoU')
    iou_panel.set_xlabel('Epoch')
    iou_panel.set_ylabel('IoU')
    iou_panel.set_title('IoU Metrics')
    iou_panel.legend()

    fig.tight_layout()
    return fig

def visualize_predictions_on_batch(model, batch_images, batch_size=8, device=None):
    """Create grid visualization for a batch of predictions.

    Each shown image gets its predicted label map blended on top via
    ``create_overlay_mask``. The grid has 4 columns and at least 2 rows (the
    original 2x4 layout). It grows extra rows instead of failing when more
    than 8 images are requested.

    Args:
        model: segmentation model exposing ``predict``. It is put in eval
            mode (as in ``plot_predictions``) and is not switched back.
        batch_images: batch tensor; each item is permuted ``(1, 2, 0)`` for
            display, so presumably ``(N, C, H, W)`` — TODO confirm.
        batch_size: maximum number of images to draw; clamped to the batch.
        device: optional device to move the batch to before prediction. The
            default ``None`` keeps the batch where it is, as before.

    Returns:
        The matplotlib ``Figure``.
    """
    model.eval()
    with torch.no_grad():
        inputs = batch_images if device is None else batch_images.to(device)
        predictions = model.predict(inputs)

    num_shown = min(batch_size, len(batch_images))
    cols = 4
    # plt.subplot(2, 4, 9) raised ValueError for more than 8 images. Keep 2
    # rows for small batches and add rows (ceil division) for larger ones.
    rows = max(2, -(-num_shown // cols))

    # 2.5 inches per row reproduces the original (15, 5) for 2 rows.
    fig = plt.figure(figsize=(15, 2.5 * rows))
    for idx in range(num_shown):
        ax = fig.add_subplot(rows, cols, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        # Assumes predictions are (num_classes, H, W) per sample — confirm.
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        ax.imshow(create_overlay_mask(img, mask))
        ax.axis('off')

    fig.tight_layout()
    return fig

def save_visualization(fig, save_path, dpi=300):
    """Save a figure to ``save_path`` and close it.

    Args:
        fig: matplotlib ``Figure`` to write.
        save_path: destination passed to ``fig.savefig``.
        dpi: output resolution; defaults to the previous fixed value of 300.

    The figure is closed even if saving raises. A bad path therefore no
    longer leaves the figure registered with pyplot; the error still
    propagates.
    """
    try:
        fig.savefig(save_path, bbox_inches='tight', dpi=dpi)
    finally:
        plt.close(fig)
    
def generate_color_mapping(num_classes):
    """Generate distinct colors for segmentation classes.

    The first ten entries come from a fixed palette (index 0 is black for
    the background). Previously, requests for more than ten classes were
    silently truncated to ten colours. Extra colours are now generated
    deterministically by stepping the hue by the golden-ratio conjugate, with
    alternating brightness, so neighbouring extra classes stay distinguishable.

    Args:
        num_classes: number of colours wanted. Values <= 0 yield ``[]``.

    Returns:
        A new list of ``[R, G, B]`` lists with integer channels in 0-255.
    """
    import colorsys

    colors = [
        [0, 0, 0],       # Background (black)
        [255, 0, 0],     # Red
        [0, 255, 0],     # Green
        [0, 0, 255],     # Blue
        [255, 255, 0],   # Yellow
        [255, 0, 255],   # Magenta
        [0, 255, 255],   # Cyan
        [128, 0, 0],     # Dark Red
        [0, 128, 0],     # Dark Green
        [0, 0, 128]      # Dark Blue
    ]
    # Negative counts used to slice from the end; treat them as "no classes".
    if num_classes <= len(colors):
        return colors[:max(num_classes, 0)]

    golden_ratio_conjugate = 0.618033988749895
    for extra in range(num_classes - len(colors)):
        # Hue offset 0.1 keeps the first extra colour away from pure red.
        hue = (0.1 + extra * golden_ratio_conjugate) % 1.0
        value = 0.9 if extra % 2 == 0 else 0.65
        rgb = colorsys.hsv_to_rgb(hue, 0.75, value)
        colors.append([int(round(channel * 255)) for channel in rgb])
    return colors