umairahmad1789 committed (verified)
Commit: efd5df3
Parent: a6fe5c5

initial commit
examples/image_000003.png ADDED
examples/image_000004.png ADDED
examples/image_000005.png ADDED
examples/image_000006.png ADDED
examples/image_000007.png ADDED
examples/image_000029.png ADDED
examples/image_000030.png ADDED
examples/image_000031.png ADDED
examples/image_000032.png ADDED
examples/image_000033.png ADDED
inference_patches.py ADDED
@@ -0,0 +1,143 @@
+ import gradio as gr
+ import torch
+ from unet import EnhancedUNet
+ import numpy as np
+ from PIL import Image
+ import math
+
+ def initialize_model(model_path):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     return model, device
+
+ def extract_patches(image, patch_size=256):
+     """Extract patches from the input image."""
+     width, height = image.size
+     patches = []
+     positions = []
+
+     # Calculate number of patches in each dimension
+     n_cols = math.ceil(width / patch_size)
+     n_rows = math.ceil(height / patch_size)
+
+     # Pad image if necessary
+     padded_width = n_cols * patch_size
+     padded_height = n_rows * patch_size
+     padded_image = Image.new('L', (padded_width, padded_height), 0)
+     padded_image.paste(image, (0, 0))
+
+     # Extract patches
+     for i in range(n_rows):
+         for j in range(n_cols):
+             left = j * patch_size
+             top = i * patch_size
+             right = left + patch_size
+             bottom = top + patch_size
+
+             patch = padded_image.crop((left, top, right, bottom))
+             patches.append(patch)
+             positions.append((left, top, right, bottom))
+
+     return patches, positions, (padded_width, padded_height), (width, height)
+
+ def stitch_patches(patches, positions, padded_size, original_size, n_classes=4):
+     """Stitch patches back together into a complete mask."""
+     full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
+
+     for patch, (left, top, right, bottom) in zip(patches, positions):
+         full_mask[top:bottom, left:right] = patch
+
+     # Crop back to original size
+     full_mask = full_mask[:original_size[1], :original_size[0]]
+     return full_mask
+
+ def process_patch(patch, device):
+     # Convert to grayscale if it's not already
+     patch_gray = patch.convert("L")
+     # Convert to numpy array and normalize
+     patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
+     # Add batch and channel dimensions
+     patch_tensor = torch.from_numpy(patch_np).float().unsqueeze(0).unsqueeze(0)
+     return patch_tensor.to(device)
+
+ def create_overlay(original_image, mask, alpha=0.5):
+     colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]  # one color per class
+     mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
+     for i, color in enumerate(colors):
+         mask_rgb[mask == i] = color
+
+     # Resize original image to match mask size
+     original_image = original_image.resize((mask.shape[1], mask.shape[0]))
+     original_array = np.array(original_image.convert("RGB"))
+
+     # Create overlay
+     overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
+     return Image.fromarray(overlay)
+
+ def predict(input_image, model_choice):
+     if input_image is None:
+         return None, None
+
+     model = models[model_choice]
+     patch_size = 256
+
+     # Extract patches
+     patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
+
+     # Process each patch
+     predicted_patches = []
+     for patch in patches:
+         # Process patch
+         patch_tensor = process_patch(patch, device)
+
+         # Perform inference
+         with torch.no_grad():
+             output = model(patch_tensor)
+
+         # Get prediction mask for patch
+         patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
+         predicted_patches.append(patch_mask)
+
+     # Stitch patches back together
+     full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
+
+     # Create mask image
+     mask_image = Image.fromarray((full_mask * 63).astype(np.uint8))  # scale class ids for better visibility
+
+     # Create overlay image
+     overlay_image = create_overlay(input_image, full_mask)
+
+     return mask_image, overlay_image
+
+ # Initialize models (done once at import time, outside the inference function, for better performance)
+ w_noise_model_path = "./models/best_model_w_noise.pth"
+ wo_noise_model_path = "./models/best_model_wo_noise.pth"
+
+ w_noise_model, device = initialize_model(w_noise_model_path)
+ wo_noise_model, device = initialize_model(wo_noise_model_path)
+
+ models = {
+     "Without Noise": wo_noise_model,
+     "With Noise": w_noise_model
+ }
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Dropdown(choices=["Without Noise", "With Noise"], value="With Noise"),
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Segmentation Mask"),
+         gr.Image(type="pil", label="Overlay"),
+     ],
+     title="MoS2 Image Segmentation",
+     description="Upload an image to get the segmentation mask and overlay visualization.",
+     examples=[["./examples/image_000003.png", "With Noise"], ["./examples/image_000005.png", "Without Noise"]],
+ )
+
+ # Launch the interface
+ iface.launch(share=True)
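
A quick way to sanity-check the patching logic above: padding is zero-filled and cropped away at the end, so extracting patches and stitching them straight back should reproduce the input exactly. A minimal sketch, assuming extract_patches and stitch_patches are in scope (paste it below the function definitions; importing the module would load the weights and launch the app):

    # Round-trip check: extract patches, stitch them back, compare to the input.
    import numpy as np
    from PIL import Image

    arr = np.random.randint(0, 256, (300, 500), dtype=np.uint8)  # H=300, W=500
    img = Image.fromarray(arr, mode="L")
    patches, positions, padded_size, original_size = extract_patches(img, patch_size=256)
    patch_arrays = [np.array(p) for p in patches]  # stand-ins for predicted patch masks
    restored = stitch_patches(patch_arrays, positions, padded_size, original_size)
    assert restored.shape == arr.shape and np.array_equal(restored, arr)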
models/best_model_w_noise.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47f9134dd87fa34d7491ee6a95838aace97c1900f261db729c9eb1e06cd16333
+ size 206643490
models/best_model_wo_noise.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33f36c8b756e81578ffca593594057b7ab0dfd335a4c5e10dd398bd9bf9b1d67
+ size 206643490
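
Both .pth files are Git LFS pointers; the actual ~207 MB state dicts resolve after `git lfs pull`. They are assumed to match the EnhancedUNet architecture in unet.py, which a short load check can confirm:

    # Sketch: verify a checkpoint loads into the architecture defined in unet.py.
    import torch
    from unet import EnhancedUNet

    model = EnhancedUNet(n_channels=1, n_classes=4)
    state = torch.load("models/best_model_w_noise.pth", map_location="cpu")
    model.load_state_dict(state)  # raises RuntimeError if keys or shapes disagree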
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ gradio
+ pillow
+ numpy
train.py ADDED
@@ -0,0 +1,410 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+ import random
+ from scipy.ndimage import gaussian_filter, map_coordinates
+ import PIL
+
+ class ResidualConvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(ResidualConvBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.in1 = nn.InstanceNorm2d(out_channels)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+         self.in2 = nn.InstanceNorm2d(out_channels)
+         self.relu = nn.LeakyReLU(inplace=True)
+         self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.in1(self.conv1(x)))
+         out = self.in2(self.conv2(out))
+         if self.downsample:
+             residual = self.downsample(x)
+         out += residual
+         return self.relu(out)
+
+ class AttentionGate(nn.Module):
+     def __init__(self, F_g, F_l, F_int):
+         super(AttentionGate, self).__init__()
+         self.W_g = nn.Sequential(
+             nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(F_int)
+         )
+         self.W_x = nn.Sequential(
+             nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(F_int)
+         )
+         self.psi = nn.Sequential(
+             nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(1),
+             nn.Sigmoid()
+         )
+         self.relu = nn.LeakyReLU(inplace=True)
+
+     def forward(self, g, x):
+         g1 = self.W_g(g)
+         x1 = self.W_x(x)
+         psi = self.relu(g1 + x1)
+         psi = self.psi(psi)
+         return x * psi
+
+ class EnhancedUNet(nn.Module):
+     def __init__(self, n_channels, n_classes):
+         super(EnhancedUNet, self).__init__()
+         self.n_channels = n_channels
+         self.n_classes = n_classes
+
+         self.inc = ResidualConvBlock(n_channels, 64)
+         self.down1 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(64, 128))
+         self.down2 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(128, 256))
+         self.down3 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(256, 512))
+         self.down4 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(512, 1024))
+
+         self.dilation = nn.Sequential(
+             nn.Conv2d(1024, 1024, kernel_size=3, padding=2, dilation=2),
+             nn.InstanceNorm2d(1024),
+             nn.LeakyReLU(inplace=True),
+             nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4),
+             nn.InstanceNorm2d(1024),
+             nn.LeakyReLU(inplace=True)
+         )
+
+         self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
+         self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
+         self.up_conv4 = ResidualConvBlock(1024, 512)
+
+         self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
+         self.up_conv3 = ResidualConvBlock(512, 256)
+
+         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
+         self.up_conv2 = ResidualConvBlock(256, 128)
+
+         self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
+         self.up_conv1 = ResidualConvBlock(128, 64)
+
+         self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
+
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, x):
+         x1 = self.inc(x)
+         x2 = self.down1(x1)
+         x2 = self.dropout(x2)
+         x3 = self.down2(x2)
+         x3 = self.dropout(x3)
+         x4 = self.down3(x3)
+         x4 = self.dropout(x4)
+         x5 = self.down4(x4)
+
+         x5 = self.dilation(x5)
+         x5 = self.dropout(x5)
+
+         x = self.up4(x5)
+         x4 = self.att4(g=x, x=x4)
+         x = torch.cat([x4, x], dim=1)
+         x = self.up_conv4(x)
+         x = self.dropout(x)
+
+         x = self.up3(x)
+         x3 = self.att3(g=x, x=x3)
+         x = torch.cat([x3, x], dim=1)
+         x = self.up_conv3(x)
+         x = self.dropout(x)
+
+         x = self.up2(x)
+         x2 = self.att2(g=x, x=x2)
+         x = torch.cat([x2, x], dim=1)
+         x = self.up_conv2(x)
+         x = self.dropout(x)
+
+         x = self.up1(x)
+         x1 = self.att1(g=x, x=x1)
+         x = torch.cat([x1, x], dim=1)
+         x = self.up_conv1(x)
+
+         logits = self.outc(x)
+         return logits
+
+ class MoS2Dataset(Dataset):
+     def __init__(self, root_dir, transform=None):
+         self.root_dir = root_dir
+         self.transform = transform
+         self.images_dir = os.path.join(root_dir, 'images')
+         self.labels_dir = os.path.join(root_dir, 'labels')
+         self.image_files = []
+         for f in sorted(os.listdir(self.images_dir)):
+             if f.endswith('.png'):
+                 try:
+                     Image.open(os.path.join(self.images_dir, f)).verify()
+                     self.image_files.append(f)
+                 except Exception:
+                     print(f"Skipping unreadable image: {f}")
+
+     def __len__(self):
+         return len(self.image_files)
+
+     def __getitem__(self, idx):
+         img_name = self.image_files[idx]
+         img_path = os.path.join(self.images_dir, img_name)
+         if not os.path.exists(img_path):
+             print(f"Image file does not exist: {img_path}")
+             return None, None
+         label_name = f"image_{img_name.split('_')[1].replace('.png', '.npy')}"
+         label_path = os.path.join(self.labels_dir, label_name)
+
+         try:
+             image = np.array(Image.open(img_path).convert('L'), dtype=np.float32) / 255.0
+             label = np.load(label_path).astype(np.int64)
+         except (PIL.UnidentifiedImageError, FileNotFoundError, IOError) as e:
+             print(f"Error loading image {img_path}: {str(e)}")
+             return None, None  # Or handle this case appropriately
+
+         if self.transform:
+             image, label = self.transform(image, label)
+
+         image = torch.from_numpy(image).float().unsqueeze(0)
+         label = torch.from_numpy(label).long()
+
+         return image, label
+
+ class AugmentationTransform:
+     def __init__(self):
+         self.aug_functions = [
+             self.random_brightness_contrast,
+             self.random_gamma,
+             self.random_noise,
+             self.random_elastic_deform
+         ]
+
+     def __call__(self, image, label):
+         for aug_func in self.aug_functions:
+             if random.random() < 0.5:  # 50% chance to apply each augmentation
+                 image, label = aug_func(image, label)
+         return image.astype(np.float32), label  # Ensure float32
+
+
+     def random_brightness_contrast(self, image, label):
+         brightness = random.uniform(0.7, 1.3)
+         contrast = random.uniform(0.7, 1.3)
+         image = np.clip(brightness * image + contrast * (image - 0.5) + 0.5, 0, 1)
+         return image, label
+
+     def random_gamma(self, image, label):
+         gamma = random.uniform(0.7, 1.3)
+         image = np.power(image, gamma)
+         return image, label
+
+     def random_noise(self, image, label):
+         noise = np.random.normal(0, 0.05, image.shape)
+         image = np.clip(image + noise, 0, 1)
+         return image, label
+
+     def random_elastic_deform(self, image, label):
+         alpha = random.uniform(10, 20)
+         sigma = random.uniform(3, 5)
+         shape = image.shape
+         dx = np.random.rand(*shape) * 2 - 1
+         dy = np.random.rand(*shape) * 2 - 1
+         dx = gaussian_filter(dx, sigma, mode="constant", cval=0) * alpha
+         dy = gaussian_filter(dy, sigma, mode="constant", cval=0) * alpha
+         x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
+         indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
+         image = map_coordinates(image, indices, order=1).reshape(shape)
+         label = map_coordinates(label, indices, order=0).reshape(shape)
+         return image, label
+
+ def focal_loss(output, target, alpha=0.25, gamma=2):
+     ce_loss = nn.CrossEntropyLoss(reduction='none')(output, target)
+     pt = torch.exp(-ce_loss)
+     focal_loss = alpha * (1 - pt) ** gamma * ce_loss
+     return focal_loss.mean()
+
+ def dice_loss(output, target, smooth=1e-5):
+     output = torch.softmax(output, dim=1)
+     num_classes = output.shape[1]
+     dice_sum = 0
+     for c in range(num_classes):
+         pred_class = output[:, c, :, :]
+         target_class = (target == c).float()
+         intersection = (pred_class * target_class).sum()
+         union = pred_class.sum() + target_class.sum()
+         dice = (2. * intersection + smooth) / (union + smooth)
+         dice_sum += dice
+     return 1 - dice_sum / num_classes
+
+ def combined_loss(output, target):
+     fl = focal_loss(output, target)
+     dl = dice_loss(output, target)
+     return 0.5 * fl + 0.5 * dl
+
+ def iou_score(output, target):
+     smooth = 1e-5
+     output = torch.argmax(output, dim=1)
+     intersection = (output & target).float().sum((1, 2))  # note: bitwise ops on class ids only approximate IoU for multi-class labels
+     union = (output | target).float().sum((1, 2))
+     iou = (intersection + smooth) / (union + smooth)
+     return iou.mean()
+
+ def pixel_accuracy(output, target):
+     output = torch.argmax(output, dim=1)
+     correct = torch.eq(output, target).int()
+     accuracy = float(correct.sum()) / float(correct.numel())
+     return accuracy
+
+ def train_one_epoch(model, dataloader, optimizer, criterion, device):
+     model.train()
+     total_loss = 0
+     total_iou = 0
+     total_accuracy = 0
+
+     pbar = tqdm(dataloader, desc='Training')
+     for images, labels in pbar:
+         images, labels = images.to(device), labels.to(device)
+
+         optimizer.zero_grad()
+         outputs = model(images)
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+
+         total_loss += loss.item()
+         total_iou += iou_score(outputs, labels)
+         total_accuracy += pixel_accuracy(outputs, labels)
+
+         pbar.set_postfix({'Loss': total_loss / (pbar.n + 1),
+                           'IoU': total_iou / (pbar.n + 1),
+                           'Accuracy': total_accuracy / (pbar.n + 1)})
+
+     return total_loss / len(dataloader), total_iou / len(dataloader), total_accuracy / len(dataloader)
+
+ def validate(model, dataloader, criterion, device):
+     model.eval()
+     total_loss = 0
+     total_iou = 0
+     total_accuracy = 0
+
+     with torch.no_grad():
+         pbar = tqdm(dataloader, desc='Validation')
+         for images, labels in pbar:
+             images, labels = images.to(device), labels.to(device)
+
+             outputs = model(images)
+             loss = criterion(outputs, labels)
+
+             total_loss += loss.item()
+             total_iou += iou_score(outputs, labels)
+             total_accuracy += pixel_accuracy(outputs, labels)
+
+             pbar.set_postfix({'Loss': total_loss / (pbar.n + 1),
+                               'IoU': total_iou / (pbar.n + 1),
+                               'Accuracy': total_accuracy / (pbar.n + 1)})
+
+     return total_loss / len(dataloader), total_iou / len(dataloader), total_accuracy / len(dataloader)
+
+ def main():
+     # Hyperparameters
+     num_classes = 4
+     batch_size = 64
+     num_epochs = 100
+     learning_rate = 1e-4
+     weight_decay = 1e-5
+
+     # Device configuration
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f"Using device: {device}")
+
+     # Create datasets and data loaders
+     transform = AugmentationTransform()
+     # dataset = MoS2Dataset('MoS2_dataset_advanced_v2', transform=transform)
+     dataset = MoS2Dataset('dataset_with_noise_npy')  # note: transform defined above is not applied to this run
+
+     train_size = int(0.8 * len(dataset))
+     val_size = len(dataset) - train_size
+     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+     # Create model
+     model = EnhancedUNet(n_channels=1, n_classes=num_classes).to(device)
+
+     # Loss and optimizer
+     criterion = combined_loss
+     optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10, verbose=True)
+
+     # Create directory for saving models and visualizations
+     save_dir = 'enhanced_training_results'
+     os.makedirs(save_dir, exist_ok=True)
+
+     # Training loop
+     best_val_iou = 0.0
+     for epoch in range(1, num_epochs + 1):
+         print(f"Epoch {epoch}/{num_epochs}")
+
+         train_loss, train_iou, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)
+         val_loss, val_iou, val_accuracy = validate(model, val_loader, criterion, device)
+
+         print(f"Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Accuracy: {train_accuracy:.4f}")
+         print(f"Val - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Accuracy: {val_accuracy:.4f}")
+
+         scheduler.step(val_iou)
+
+         if val_iou > best_val_iou:
+             best_val_iou = val_iou
+             torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
+             print(f"New best model saved with IoU: {best_val_iou:.4f}")
+
+         # Save checkpoint
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'best_val_iou': best_val_iou,
+         }, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))
+
+         # Visualize predictions after each epoch
+
+         visualize_prediction(model, val_loader, device, epoch, save_dir)
+
+     print("Training completed!")
+
+ def visualize_prediction(model, val_loader, device, epoch, save_dir):
+     model.eval()
+     images, labels = next(iter(val_loader))
+     images, labels = images.to(device), labels.to(device)
+     with torch.no_grad():
+         outputs = model(images)
+
+     images = images.cpu().numpy()
+     labels = labels.cpu().numpy()
+     predictions = torch.argmax(outputs, dim=1).cpu().numpy()
+
+     fig, axs = plt.subplots(2, 3, figsize=(15, 10))
+     axs[0, 0].imshow(images[0, 0], cmap='gray')
+     axs[0, 0].set_title('Input Image')
+     axs[0, 1].imshow(labels[0], cmap='viridis')
+     axs[0, 1].set_title('True Label')
+     axs[0, 2].imshow(predictions[0], cmap='viridis')
+     axs[0, 2].set_title('Prediction')
+     axs[1, 0].imshow(images[1, 0], cmap='gray')
+     axs[1, 1].imshow(labels[1], cmap='viridis')
+     axs[1, 2].imshow(predictions[1], cmap='viridis')
+     plt.tight_layout()
+     plt.savefig(os.path.join(save_dir, f'prediction_epoch_{epoch}.png'))
+     plt.close()
+
+ if __name__ == "__main__":
+     main()
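
MoS2Dataset pairs images/image_XXXXXX.png with labels/image_XXXXXX.npy via the label_name construction in __getitem__. A minimal sketch of the directory layout it expects, with MoS2Dataset in scope and illustrative file names and sizes:

    # Build a one-sample dummy dataset in the layout MoS2Dataset reads.
    import os
    import numpy as np
    from PIL import Image

    root = "dataset_with_noise_npy"  # the root_dir passed to MoS2Dataset in main()
    os.makedirs(os.path.join(root, "images"), exist_ok=True)
    os.makedirs(os.path.join(root, "labels"), exist_ok=True)

    # images/image_000001.png pairs with labels/image_000001.npy
    Image.fromarray(np.zeros((256, 256), dtype=np.uint8)).save(os.path.join(root, "images", "image_000001.png"))
    np.save(os.path.join(root, "labels", "image_000001.npy"), np.zeros((256, 256), dtype=np.int64))

    ds = MoS2Dataset(root)
    image, label = ds[0]
    assert image.shape == (1, 256, 256) and label.shape == (256, 256)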
unet.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ResidualConvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(ResidualConvBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.in1 = nn.InstanceNorm2d(out_channels)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+         self.in2 = nn.InstanceNorm2d(out_channels)
+         self.relu = nn.LeakyReLU(inplace=True)
+         self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.in1(self.conv1(x)))
+         out = self.in2(self.conv2(out))
+         if self.downsample:
+             residual = self.downsample(x)
+         out += residual
+         return self.relu(out)
+
+ class AttentionGate(nn.Module):
+     def __init__(self, F_g, F_l, F_int):
+         super(AttentionGate, self).__init__()
+         self.W_g = nn.Sequential(
+             nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(F_int)
+         )
+         self.W_x = nn.Sequential(
+             nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(F_int)
+         )
+         self.psi = nn.Sequential(
+             nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(1),
+             nn.Sigmoid()
+         )
+         self.relu = nn.LeakyReLU(inplace=True)
+
+     def forward(self, g, x):
+         g1 = self.W_g(g)
+         x1 = self.W_x(x)
+         psi = self.relu(g1 + x1)
+         psi = self.psi(psi)
+         return x * psi
+
+ class EnhancedUNet(nn.Module):
+     def __init__(self, n_channels, n_classes):
+         super(EnhancedUNet, self).__init__()
+         self.n_channels = n_channels
+         self.n_classes = n_classes
+
+         self.inc = ResidualConvBlock(n_channels, 64)
+         self.down1 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(64, 128))
+         self.down2 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(128, 256))
+         self.down3 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(256, 512))
+         self.down4 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(512, 1024))
+
+         self.dilation = nn.Sequential(
+             nn.Conv2d(1024, 1024, kernel_size=3, padding=2, dilation=2),
+             nn.InstanceNorm2d(1024),
+             nn.LeakyReLU(inplace=True),
+             nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4),
+             nn.InstanceNorm2d(1024),
+             nn.LeakyReLU(inplace=True)
+         )
+
+         self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
+         self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
+         self.up_conv4 = ResidualConvBlock(1024, 512)
+
+         self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
+         self.up_conv3 = ResidualConvBlock(512, 256)
+
+         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
+         self.up_conv2 = ResidualConvBlock(256, 128)
+
+         self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
+         self.up_conv1 = ResidualConvBlock(128, 64)
+
+         self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
+
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, x):
+         x1 = self.inc(x)
+         x2 = self.down1(x1)
+         x2 = self.dropout(x2)
+         x3 = self.down2(x2)
+         x3 = self.dropout(x3)
+         x4 = self.down3(x3)
+         x4 = self.dropout(x4)
+         x5 = self.down4(x4)
+
+         x5 = self.dilation(x5)
+         x5 = self.dropout(x5)
+
+         x = self.up4(x5)
+         x4 = self.att4(g=x, x=x4)
+         x = torch.cat([x4, x], dim=1)
+         x = self.up_conv4(x)
+         x = self.dropout(x)
+
+         x = self.up3(x)
+         x3 = self.att3(g=x, x=x3)
+         x = torch.cat([x3, x], dim=1)
+         x = self.up_conv3(x)
+         x = self.dropout(x)
+
+         x = self.up2(x)
+         x2 = self.att2(g=x, x=x2)
+         x = torch.cat([x2, x], dim=1)
+         x = self.up_conv2(x)
+         x = self.dropout(x)
+
+         x = self.up1(x)
+         x1 = self.att1(g=x, x=x1)
+         x = torch.cat([x1, x], dim=1)
+         x = self.up_conv1(x)
+
+         logits = self.outc(x)
+         return logits
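
For reference, a quick shape check of the network (a sketch; the input side length must be a multiple of 16 so the four MaxPool2d stages and the matching ConvTranspose2d stages line up with the skip connections):

    # Verify that a grayscale 256x256 input yields per-pixel logits for 4 classes.
    import torch
    from unet import EnhancedUNet

    model = EnhancedUNet(n_channels=1, n_classes=4)
    model.eval()  # disable the Dropout layers
    with torch.no_grad():
        logits = model(torch.randn(1, 1, 256, 256))
    assert logits.shape == (1, 4, 256, 256)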