thov commited on
Commit
e6f4cd4
1 Parent(s): 1df7042

add training

Browse files
Files changed (4) hide show
  1. UNET_perso.py +75 -0
  2. main.py +150 -0
  3. src/medicalDataLoader.py +3 -1
  4. src/utils.py +294 -0
UNET_perso.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms.functional as TF
4
+
5
+
6
+ #Aladdinpersson/Machine-Learning-Collection GIT
7
+
8
+ """
9
+ Defining a UNet block
10
+ in_channels: image dimension
11
+ """
12
+ class DoubleConv(nn.Module):
13
+ def __init__(self, in_channels, out_channels):
14
+ super(DoubleConv, self).__init__()
15
+ self.conv = nn.Sequential(
16
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
17
+ nn.BatchNorm2d(out_channels),
18
+ nn.ReLU(inplace=True),
19
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
20
+ nn.BatchNorm2d(out_channels),
21
+ nn.ReLU(inplace=True),
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.conv(x)
26
+
27
+ class UNET(nn.Module):
28
+ def __init__(
29
+ self, in_channels=3, out_channels=4, features=[64, 128, 256, 512]):
30
+ super(UNET, self).__init__()
31
+ self.ups = nn.ModuleList()
32
+ self.downs = nn.ModuleList()
33
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
34
+
35
+ # Down part of UNET
36
+ for feature in features:
37
+ self.downs.append(DoubleConv(in_channels, feature))
38
+ in_channels = feature
39
+
40
+ # Up part of UNET
41
+ for feature in reversed(features):
42
+ self.ups.append(
43
+ nn.ConvTranspose2d(
44
+ feature*2, feature, kernel_size=2, stride=2,
45
+ )
46
+ )
47
+ self.ups.append(DoubleConv(feature*2, feature))
48
+
49
+ #layer between down part and up part
50
+ self.bottleneck = DoubleConv(features[-1], features[-1]*2)
51
+ self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
52
+
53
+ def forward(self, x):
54
+ skip_connections = []
55
+
56
+ for down in self.downs:
57
+ x = down(x)
58
+ skip_connections.append(x)
59
+ x = self.pool(x)
60
+
61
+ x = self.bottleneck(x)
62
+ skip_connections = skip_connections[::-1]
63
+
64
+ for idx in range(0, len(self.ups), 2):
65
+ x = self.ups[idx](x)
66
+ skip_connection = skip_connections[idx//2]
67
+
68
+ #Double check if input size is not divisible by 2, we need to be sure that the two shapes are similar
69
+ if x.shape != skip_connection.shape:
70
+ x = TF.resize(x, size=skip_connection.shape[2:])
71
+
72
+ concat_skip = torch.cat((skip_connection, x), dim=1)
73
+ x = self.ups[idx+1](concat_skip)
74
+
75
+ return self.final_conv(x)
main.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms
11
+ from torchmetrics.functional import dice, jaccard_index, accuracy
12
+
13
+ from segmentation_models_pytorch.losses import DiceLoss, TverskyLoss, FocalLoss
14
+
15
+ from src.medicalDataLoader import MedicalImageDataset
16
+ from src.utils import getTargetSegmentation, plot_img
17
+ from UNET_perso import UNET
18
+
19
+ ## Parameters & Hyperparameters ##
20
+ EPOCHS = 2
21
+ BATCH_SIZE_TRAIN = 8
22
+ BATCH_SIZE_VAL = 8
23
+ LR = 1e-3
24
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
25
+ torch.cuda.empty_cache()
26
+
27
+ ## Model ##
28
+ model = UNET(in_channels=1, out_channels=4).to(DEVICE)
29
+
30
+ ## Loss ##
31
+ lossCE = nn.CrossEntropyLoss().to(DEVICE)
32
+ lossDice = DiceLoss(mode='multiclass').to(DEVICE)
33
+
34
+ ## optimizer ##
35
+ #optimizer = torch.optim.Adam(model.parameters(), lr=LR)
36
+ optimizer = torch.optim.NAdam(model.parameters(), lr=LR)
37
+
38
+
39
+ transform = transforms.Compose([transforms.ToTensor()])
40
+ ROOT_DIR = './Data'
41
+ train_set = MedicalImageDataset('train',
42
+ ROOT_DIR,
43
+ transform=transform,
44
+ mask_transform=transform,
45
+ augment=True,
46
+ equalize=False)
47
+
48
+ train_loader = DataLoader(train_set,
49
+ batch_size=BATCH_SIZE_TRAIN,
50
+ shuffle=True)
51
+
52
+ val_set = MedicalImageDataset('val',
53
+ ROOT_DIR,
54
+ transform=transform,
55
+ mask_transform=transform,
56
+ equalize=False)
57
+
58
+ val_loader = DataLoader(val_set,
59
+ batch_size=BATCH_SIZE_VAL,
60
+ shuffle=False)
61
+
62
+ test_set = MedicalImageDataset('test',
63
+ ROOT_DIR,
64
+ transform=transform,
65
+ mask_transform=transform,
66
+ equalize=False)
67
+
68
+ test_loader = DataLoader(test_set,
69
+ batch_size=BATCH_SIZE_VAL,
70
+ shuffle=False)
71
+
72
+
73
+ def train(dataLoader, model, optimizer, epoch, loss_fn1, loss_fn2=None):
74
+ print(f'~~~ train for epoch {epoch} ~~~')
75
+ model.train()
76
+ loop = tqdm(dataLoader)
77
+ train_loss = 0
78
+ for i, (img, labels, name) in enumerate(loop):
79
+ #if torch.cuda.is_available():
80
+ labels = getTargetSegmentation(labels)
81
+ img, labels = img.to(DEVICE), labels.to(DEVICE)
82
+ yPred = model(img)
83
+ if loss_fn2!=None:
84
+ loss = 0.5*loss_fn1(yPred, labels) + 0.5*loss_fn2(yPred, labels)
85
+ else : loss = loss_fn1(yPred, labels)
86
+
87
+ train_loss += loss.item()
88
+
89
+ optimizer.zero_grad()
90
+ loss.backward()
91
+ optimizer.step()
92
+
93
+ loop.set_postfix(loss=loss.item()/len(dataLoader))
94
+ print('total train loss : {:.4f}\n'.format(train_loss/len(dataLoader.dataset)))
95
+ return model, train_loss/len(dataLoader.dataset)
96
+
97
+
98
+ def test(dataLoader, model, loss_fn, epoch):
99
+ print(f'~~~ validation for epoch {epoch} ~~~')
100
+ model.eval()
101
+ size = len(dataLoader)
102
+ loop = tqdm(dataLoader)
103
+ test_loss = 0
104
+ Acc = 0
105
+ Dsc1, Dsc2, Dsc3 = 0, 0, 0
106
+ IOU1, IOU2, IOU3 = 0, 0, 0
107
+ for i, (img, labels, name) in enumerate(loop):
108
+ #if torch.cuda.is_available():
109
+ labels = getTargetSegmentation(labels)
110
+ img, labels = img.to(DEVICE), labels.to(DEVICE)
111
+
112
+ yPred = model(img)
113
+ loss = loss_fn(yPred, labels)
114
+ test_loss += loss.item()
115
+ loop.set_postfix(loss=loss.item()/len(dataLoader))
116
+
117
+ Dsc = dice(yPred, labels, average='none', num_classes=4).cpu()
118
+ IOU = jaccard_index(yPred, labels, task='multiclass', average='none', num_classes=4).cpu()
119
+ Dsc1 += Dsc[1]
120
+ Dsc2 += Dsc[2]
121
+ Dsc3 += Dsc[3]
122
+ IOU1 += IOU[1]
123
+ IOU2 += IOU[2]
124
+ IOU3 += IOU[3]
125
+ print('total test loss : {:.4f}\nDice score 1 : {:.4f} | Dice score 2 : {:.4f} | Dice score 3 : {:.4f}\nIOU 1 : {:.4f} | IOU 2 : {:.4f} | IOU 3 : {:.4f}\n'.format(test_loss/size, Dsc1/size, Dsc2/size, Dsc3/size, IOU1/size, IOU2/size, IOU3/size))
126
+ return test_loss/size, Dsc1/size, Dsc2/size, Dsc3/size, IOU1/size, IOU2/size, IOU3/size
127
+
128
+
129
+ def main(train_loader, test_loader, model, optimizer, loss1, loss2):
130
+ train_loss_lst, test_loss_lst = [], []
131
+ Dsc1_lst, Dsc2_lst, Dsc3_lst = [], [], []
132
+ IOU1_lst, IOU2_lst, IOU3_lst = [], [], []
133
+ for i in range(EPOCHS):
134
+ model, train_loss = train(train_loader, model, optimizer, i+1, loss_fn1=loss1, loss_fn2=loss2)
135
+ test_loss, Dsc1, Dsc2, Dsc3, IOU1, IOU2, IOU3 = test(test_loader, model, loss1, i+1)
136
+ train_loss_lst.append(train_loss)
137
+ test_loss_lst.append(test_loss)
138
+ Dsc1_lst.append(Dsc1)
139
+ Dsc2_lst.append(Dsc2)
140
+ Dsc3_lst.append(Dsc3)
141
+ IOU1_lst.append(IOU1)
142
+ IOU2_lst.append(IOU2)
143
+ IOU3_lst.append(IOU3)
144
+
145
+ return model
146
+
147
+
148
+ if __name__=='__main__':
149
+ mdoel = main(train_loader, test_loader, model, optimizer, loss1=lossCE, loss2=lossDice)
150
+ plot_img(test_loader, 8, model)
src/medicalDataLoader.py CHANGED
@@ -110,4 +110,6 @@ class MedicalImageDataset(Dataset):
110
  img = self.transform(img)
111
  mask = self.mask_transform(mask)
112
 
113
- return [img, mask, img_path]
 
 
 
110
  img = self.transform(img)
111
  mask = self.mask_transform(mask)
112
 
113
+ return [img, mask, img_path]
114
+
115
+
src/utils.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ import torchvision
7
+ import os
8
+ import skimage.transform as skiTransf
9
+ import scipy.io as sio
10
+ import pdb
11
+ import time
12
+ import re
13
+ from os.path import isfile, join
14
+ import statistics
15
+ from PIL import Image
16
+ from medpy.metric.binary import dc, hd, asd, assd
17
+ import scipy.spatial
18
+ import matplotlib.pyplot as plt
19
+ from IPython.display import Image, display
20
+ from skimage import io
21
+ import cv2
22
+
23
+ # from scipy.spatial.distance import directed_hausdorff
24
+
25
+
26
+ labels = {0: 'Background', 1: 'Foreground'}
27
+
28
+
29
+ def computeDSC(pred, gt):
30
+ dscAll = []
31
+ #pdb.set_trace()
32
+ for i_b in range(pred.shape[0]):
33
+ pred_id = pred[i_b, 1, :]
34
+ gt_id = gt[i_b, 0, :]
35
+ dscAll.append(dc(pred_id.cpu().data.numpy(), gt_id.cpu().data.numpy()))
36
+
37
+ DSC = np.asarray(dscAll)
38
+
39
+ return DSC.mean()
40
+
41
+
42
+ def getImageImageList(imagesFolder):
43
+ if os.path.exists(imagesFolder):
44
+ imageNames = [f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f))]
45
+
46
+ imageNames.sort()
47
+
48
+ return imageNames
49
+
50
+
51
+ def to_var(x):
52
+ if torch.cuda.is_available():
53
+ x = x.cuda()
54
+ return Variable(x)
55
+
56
+
57
+ def DicesToDice(Dices):
58
+ sums = Dices.sum(dim=0)
59
+ return (2 * sums[0] + 1e-8) / (sums[1] + 1e-8)
60
+
61
+
62
+ def predToSegmentation(pred):
63
+ Max = pred.max(dim=1, keepdim=True)[0]
64
+ x = pred / Max
65
+ # pdb.set_trace()
66
+ return (x == 1).float()
67
+
68
+
69
+ def getTargetSegmentation(batch):
70
+ # input is 1-channel of values between 0 and 1
71
+ # values are as follows : 0, 0.33333334, 0.6666667 and 0.94117647
72
+ # output is 1 channel of discrete values : 0, 1, 2 and 3
73
+
74
+ denom = 0.33333334 # for ACDC this value
75
+ return (batch / denom).round().long().squeeze()
76
+
77
+
78
+ from scipy import ndimage
79
+
80
+
81
+ def inference(net, img_batch, modelName, epoch):
82
+ total = len(img_batch)
83
+ net.eval()
84
+
85
+ softMax = nn.Softmax().cuda()
86
+ CE_loss = nn.CrossEntropyLoss().cuda()
87
+
88
+ losses = []
89
+ for i, data in enumerate(img_batch):
90
+
91
+ printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
92
+ images, labels, img_names = data
93
+
94
+ images = to_var(images)
95
+ labels = to_var(labels)
96
+
97
+ net_predictions = net(images)
98
+ segmentation_classes = getTargetSegmentation(labels)
99
+ CE_loss_value = CE_loss(net_predictions, segmentation_classes)
100
+ losses.append(CE_loss_value.cpu().data.numpy())
101
+ pred_y = softMax(net_predictions)
102
+ masks = torch.argmax(pred_y, dim=1)
103
+
104
+ path = os.path.join('./Results/Images/', modelName, str(epoch))
105
+
106
+ if not os.path.exists(path):
107
+ os.makedirs(path)
108
+
109
+ torchvision.utils.save_image(
110
+ torch.cat([images.data, labels.data, masks.view(labels.shape[0], 1, 256, 256).data / 3.0]),
111
+ os.path.join(path, str(i) + '.png'), padding=0)
112
+
113
+ printProgressBar(total, total, done="[Inference] Segmentation Done !")
114
+
115
+ losses = np.asarray(losses)
116
+
117
+ return losses.mean()
118
+
119
+
120
+ class MaskToTensor(object):
121
+ def __call__(self, img):
122
+ return torch.from_numpy(np.array(img, dtype=np.int32)).float()
123
+
124
+
125
+ def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
126
+ print("=> Saving checkpoint")
127
+ torch.save(state, filename)
128
+
129
+ def load_checkpoint(checkpoint, model):
130
+ print("=> Loading checkpoint")
131
+ model.load_state_dict(checkpoint["state_dict"])
132
+
133
+ def check_accuracy(loader, model, device="cuda"):
134
+ num_correct = 0
135
+ num_pixels = 0
136
+ dice_score = 0
137
+ model.eval()
138
+
139
+ with torch.no_grad():
140
+ for x, y in loader:
141
+ x = x.to(device)
142
+ y = y.to(device).unsqueeze(1)
143
+ preds = torch.sigmoid(model(x))
144
+ preds = (preds > 0.5).float()
145
+ num_correct += (preds == y).sum()
146
+ num_pixels += torch.numel(preds)
147
+ dice_score += (2 * (preds * y).sum()) / (
148
+ (preds + y).sum() + 1e-8
149
+ )
150
+
151
+ print(
152
+ f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
153
+ )
154
+ print(f"Dice score: {dice_score/len(loader)}")
155
+ model.train()
156
+
157
+ def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
158
+ model.eval()
159
+ for idx, (x, y) in enumerate(loader):
160
+ x = x.to(device=device)
161
+ with torch.no_grad():
162
+ preds = torch.sigmoid(model(x))
163
+ preds = (preds > 0.5).float()
164
+ torchvision.utils.save_image(
165
+ preds, f"{folder}/pred_{idx}.png"
166
+ )
167
+ torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
168
+
169
+ model.train()
170
+
171
+
172
+ # converting tensor to image
173
+ def image_convert(image):
174
+ image = image.clone().cpu().numpy()
175
+ image = image.transpose((1,2,0))
176
+ image = (image * 255)
177
+ return image
178
+
179
+ def mask_convert(mask):
180
+ mask = mask.clone().cpu().detach().numpy()
181
+ return np.squeeze(mask)
182
+
183
+ #If model is true, this will run inference on some test image and show the output on a plot
184
+ def plot_img(loader, no_, model=None):
185
+ images, target, name = next(iter(loader))
186
+ ind = np.random.choice(range(loader.batch_size))
187
+
188
+ data= to_var(images)
189
+
190
+ for idx in range(0,no_):
191
+ plt.figure(figsize=(12,12))
192
+
193
+ #Images
194
+ image = image_convert(images[idx])
195
+ plt.subplot(1,3,1)
196
+ plt.imshow(image)
197
+ plt.title('Original Image')
198
+
199
+ #Ground truth target mask
200
+ mask = mask_convert(target[idx])
201
+ plt.subplot(1,3,2)
202
+ plt.imshow(mask)
203
+ plt.title('Original Mask')
204
+
205
+ if model is None:
206
+ #superposition with target mask
207
+ plt.subplot(1,3,3)
208
+ plt.imshow(image)
209
+ plt.imshow(mask,alpha=0.6)
210
+ plt.title('Superposition')
211
+ else:
212
+ softMax = nn.Softmax().cuda()
213
+ #showing prediction mask
214
+ plt.subplot(1,3,3)
215
+ #make a prediction bases on the previous image
216
+ yhat = model(data)
217
+ pred_y = softMax(yhat)
218
+ masks = torch.argmax(pred_y, dim=1)
219
+ plt.imshow(mask_convert(masks[idx]))
220
+ plt.title('Prediction')
221
+ plt.show()
222
+
223
+
224
+
225
+
226
+ """
227
+ def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False):
228
+ train_transform = A.Compose(
229
+ [
230
+ A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
231
+ A.Rotate(limit=35, p=1.0),
232
+ A.HorizontalFlip(p=0.5),
233
+ A.VerticalFlip(p=0.1),
234
+ A.Normalize(
235
+ mean=[0.0, 0.0, 0.0],
236
+ std=[1.0, 1.0, 1.0],
237
+ max_pixel_value=255.0,
238
+ ),
239
+ ToTensorV2(),
240
+ ],
241
+ )
242
+
243
+ val_transform = A.Compose(
244
+ [
245
+ A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
246
+ A.Normalize(
247
+ mean=[0.0, 0.0, 0.0],
248
+ std=[1.0, 1.0, 1.0],
249
+ max_pixel_value=255.0,
250
+ ),
251
+ ToTensorV2(),
252
+ ],
253
+ )
254
+
255
+ ## DUE TO THE CUSTOM LOADING CLASS, HE NEED TO USE TO STEP TO LOAD DATA
256
+ train_set_full = medicalDataLoader.MedicalImageDataset('train',
257
+ root_dir,
258
+ transform=train_transform,
259
+ mask_transform=train_transform,
260
+ augment=False,
261
+ equalize=False)
262
+
263
+ train_loader_full = DataLoader(train_set_full,
264
+ batch_size=batch_size,
265
+ worker_init_fn=np.random.seed(0),
266
+ num_workers= 0,
267
+ shuffle=True)
268
+
269
+ val_set = medicalDataLoader.MedicalImageDataset('val',
270
+ root_dir,
271
+ transform=val_transform,
272
+ mask_transform=val_transform,
273
+ equalize=False)
274
+
275
+ val_loader = DataLoader(val_set,
276
+ batch_size=batch_size,
277
+ worker_init_fn=np.random.seed(0),
278
+ num_workers = 0,
279
+ shuffle=False)
280
+
281
+ if test:
282
+ test_set = medicalDataLoader.MedicalImageDataset('test',
283
+ root_dir,
284
+ transform=None,
285
+ mask_transform=None,
286
+ equalize=False)
287
+
288
+ test_loader = DataLoader(test_set,
289
+ batch_size=batch_size,
290
+ num_workers=0,
291
+ shuffle=False)
292
+ return test_loader
293
+
294
+ return train_loader_full, val_loader"""