import gradio as gr
import torch
from unet import EnhancedUNet
import numpy as np
from PIL import Image
import math

def initialize_model(model_path):
    """Load a trained EnhancedUNet checkpoint onto the available device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    # map_location keeps CPU-only machines working; on newer PyTorch,
    # weights_only=True is a safer option for plain state_dict checkpoints
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device

def extract_patches(image, patch_size=256):
    """Extract patches from the input image."""
    width, height = image.size
    patches = []
    positions = []
    
    # Calculate number of patches in each dimension
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)
    
    # Pad image if necessary
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new('L', (padded_width, padded_height), 0)
    padded_image.paste(image, (0, 0))
    
    # Extract patches
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            
            patch = padded_image.crop((left, top, right, bottom))
            patches.append(patch)
            positions.append((left, top, right, bottom))
    
    return patches, positions, (padded_width, padded_height), (width, height)
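
# Worked example: a 600x400 input gives n_cols = ceil(600/256) = 3 and
# n_rows = ceil(400/256) = 2, so the image is zero-padded to 768x512 and
# split into six 256x256 tiles, each recorded with its (left, top, right,
# bottom) box for later stitching.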

def stitch_patches(patches, positions, padded_size, original_size):
    """Stitch predicted patch masks back into a single full-size mask."""
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    
    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch
        
    # Crop back to original size
    full_mask = full_mask[:original_size[1], :original_size[0]]
    return full_mask
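
# Continuing the example: six 256x256 predicted masks fill a 512x768 array
# (rows x cols), which is then cropped back to the original 400x600.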

def process_patch(patch, device):
    """Convert a PIL patch to a model-ready tensor on the target device."""
    # Convert to grayscale if it's not already
    patch_gray = patch.convert("L")
    # Convert to numpy array and normalize
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    patch_tensor = torch.from_numpy(patch_np).unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)
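
# A 256x256 patch becomes a (1, 1, 256, 256) tensor scaled to [0, 1],
# matching the single-channel input (n_channels=1) EnhancedUNet expects.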

def create_overlay(original_image, mask, alpha=0.5):
    """Blend a color-coded class mask over the original image."""
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]  # one RGB color per class
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color

    # Resize original image to match mask size
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))

    # Create overlay
    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)
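
# The blend computes alpha * class_color + (1 - alpha) * pixel, so the
# default alpha=0.5 weights the class colors and the raw image equally.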

def predict(input_image, model_choice):
    """Segment an image patch-by-patch; return (mask image, overlay image)."""
    if input_image is None:
        return None, None
        
    model = models[model_choice]
    patch_size = 256
    
    # Extract patches
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
    
    # Process each patch
    predicted_patches = []
    for patch in patches:
        # Process patch
        patch_tensor = process_patch(patch, device)
        
        # Perform inference
        with torch.no_grad():
            output = model(patch_tensor)
        
        # Get prediction mask for patch
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)
    
    # Stitch patches back together
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
    
    # Create mask image
    mask_image = Image.fromarray((full_mask * 85).astype(np.uint8))  # map classes 0-3 to 0, 85, 170, 255 for visibility
    
    # Create overlay image
    overlay_image = create_overlay(input_image, full_mask)
    
    return mask_image, overlay_image

# Load all models once at startup so each prediction avoids reloading weights
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"

# Every checkpoint loads onto the same device, so reusing the last returned handle is safe
w_noise_model, device = initialize_model(w_noise_model_path)
wo_noise_model, device = initialize_model(wo_noise_model_path)
w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)

models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2
}
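
# The dropdown labels below must match these keys: predict() uses the
# selected label to look up the model.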

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=["Without Noise", "With Noise", "With Noise V2"], value="With Noise V2"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Overlay"),
    ],
    title="MoS2 Image Segmentation",
    description="Upload an image to get the segmentation mask and overlay visualization.",
    examples=[["./examples/image_000003.png", "With Noise"], ["./examples/image_000005.png", "Without Noise"]],
)

# Launch the interface
iface.launch(share=True)
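
# share=True also creates a temporary public Gradio link; use iface.launch()
# to serve on localhost only.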