Spaces:
Sleeping
Sleeping
add training
Browse files- UNET_perso.py +75 -0
- main.py +150 -0
- src/medicalDataLoader.py +3 -1
- src/utils.py +294 -0
UNET_perso.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.transforms.functional as TF
|
4 |
+
|
5 |
+
|
6 |
+
#Aladdinpersson/Machine-Learning-Collection GIT
|
7 |
+
|
8 |
+
"""
|
9 |
+
Defining a UNet block
|
10 |
+
in_channels: image dimension
|
11 |
+
"""
|
12 |
+
class DoubleConv(nn.Module):
|
13 |
+
def __init__(self, in_channels, out_channels):
|
14 |
+
super(DoubleConv, self).__init__()
|
15 |
+
self.conv = nn.Sequential(
|
16 |
+
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
|
17 |
+
nn.BatchNorm2d(out_channels),
|
18 |
+
nn.ReLU(inplace=True),
|
19 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
|
20 |
+
nn.BatchNorm2d(out_channels),
|
21 |
+
nn.ReLU(inplace=True),
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return self.conv(x)
|
26 |
+
|
27 |
+
class UNET(nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self, in_channels=3, out_channels=4, features=[64, 128, 256, 512]):
|
30 |
+
super(UNET, self).__init__()
|
31 |
+
self.ups = nn.ModuleList()
|
32 |
+
self.downs = nn.ModuleList()
|
33 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
34 |
+
|
35 |
+
# Down part of UNET
|
36 |
+
for feature in features:
|
37 |
+
self.downs.append(DoubleConv(in_channels, feature))
|
38 |
+
in_channels = feature
|
39 |
+
|
40 |
+
# Up part of UNET
|
41 |
+
for feature in reversed(features):
|
42 |
+
self.ups.append(
|
43 |
+
nn.ConvTranspose2d(
|
44 |
+
feature*2, feature, kernel_size=2, stride=2,
|
45 |
+
)
|
46 |
+
)
|
47 |
+
self.ups.append(DoubleConv(feature*2, feature))
|
48 |
+
|
49 |
+
#layer between down part and up part
|
50 |
+
self.bottleneck = DoubleConv(features[-1], features[-1]*2)
|
51 |
+
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
skip_connections = []
|
55 |
+
|
56 |
+
for down in self.downs:
|
57 |
+
x = down(x)
|
58 |
+
skip_connections.append(x)
|
59 |
+
x = self.pool(x)
|
60 |
+
|
61 |
+
x = self.bottleneck(x)
|
62 |
+
skip_connections = skip_connections[::-1]
|
63 |
+
|
64 |
+
for idx in range(0, len(self.ups), 2):
|
65 |
+
x = self.ups[idx](x)
|
66 |
+
skip_connection = skip_connections[idx//2]
|
67 |
+
|
68 |
+
#Double check if input size is not divisible by 2, we need to be sure that the two shapes are similar
|
69 |
+
if x.shape != skip_connection.shape:
|
70 |
+
x = TF.resize(x, size=skip_connection.shape[2:])
|
71 |
+
|
72 |
+
concat_skip = torch.cat((skip_connection, x), dim=1)
|
73 |
+
x = self.ups[idx+1](concat_skip)
|
74 |
+
|
75 |
+
return self.final_conv(x)
|
main.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchmetrics.functional import dice, jaccard_index, accuracy
|
12 |
+
|
13 |
+
from segmentation_models_pytorch.losses import DiceLoss, TverskyLoss, FocalLoss
|
14 |
+
|
15 |
+
from src.medicalDataLoader import MedicalImageDataset
|
16 |
+
from src.utils import getTargetSegmentation, plot_img
|
17 |
+
from UNET_perso import UNET
|
18 |
+
|
19 |
+
## Parameters & Hyperparameters ##
|
20 |
+
EPOCHS = 2
|
21 |
+
BATCH_SIZE_TRAIN = 8
|
22 |
+
BATCH_SIZE_VAL = 8
|
23 |
+
LR = 1e-3
|
24 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
25 |
+
torch.cuda.empty_cache()
|
26 |
+
|
27 |
+
## Model ##
|
28 |
+
model = UNET(in_channels=1, out_channels=4).to(DEVICE)
|
29 |
+
|
30 |
+
## Loss ##
|
31 |
+
lossCE = nn.CrossEntropyLoss().to(DEVICE)
|
32 |
+
lossDice = DiceLoss(mode='multiclass').to(DEVICE)
|
33 |
+
|
34 |
+
## optimizer ##
|
35 |
+
#optimizer = torch.optim.Adam(model.parameters(), lr=LR)
|
36 |
+
optimizer = torch.optim.NAdam(model.parameters(), lr=LR)
|
37 |
+
|
38 |
+
|
39 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
40 |
+
ROOT_DIR = './Data'
|
41 |
+
train_set = MedicalImageDataset('train',
|
42 |
+
ROOT_DIR,
|
43 |
+
transform=transform,
|
44 |
+
mask_transform=transform,
|
45 |
+
augment=True,
|
46 |
+
equalize=False)
|
47 |
+
|
48 |
+
train_loader = DataLoader(train_set,
|
49 |
+
batch_size=BATCH_SIZE_TRAIN,
|
50 |
+
shuffle=True)
|
51 |
+
|
52 |
+
val_set = MedicalImageDataset('val',
|
53 |
+
ROOT_DIR,
|
54 |
+
transform=transform,
|
55 |
+
mask_transform=transform,
|
56 |
+
equalize=False)
|
57 |
+
|
58 |
+
val_loader = DataLoader(val_set,
|
59 |
+
batch_size=BATCH_SIZE_VAL,
|
60 |
+
shuffle=False)
|
61 |
+
|
62 |
+
test_set = MedicalImageDataset('test',
|
63 |
+
ROOT_DIR,
|
64 |
+
transform=transform,
|
65 |
+
mask_transform=transform,
|
66 |
+
equalize=False)
|
67 |
+
|
68 |
+
test_loader = DataLoader(test_set,
|
69 |
+
batch_size=BATCH_SIZE_VAL,
|
70 |
+
shuffle=False)
|
71 |
+
|
72 |
+
|
73 |
+
def train(dataLoader, model, optimizer, epoch, loss_fn1, loss_fn2=None):
|
74 |
+
print(f'~~~ train for epoch {epoch} ~~~')
|
75 |
+
model.train()
|
76 |
+
loop = tqdm(dataLoader)
|
77 |
+
train_loss = 0
|
78 |
+
for i, (img, labels, name) in enumerate(loop):
|
79 |
+
#if torch.cuda.is_available():
|
80 |
+
labels = getTargetSegmentation(labels)
|
81 |
+
img, labels = img.to(DEVICE), labels.to(DEVICE)
|
82 |
+
yPred = model(img)
|
83 |
+
if loss_fn2!=None:
|
84 |
+
loss = 0.5*loss_fn1(yPred, labels) + 0.5*loss_fn2(yPred, labels)
|
85 |
+
else : loss = loss_fn1(yPred, labels)
|
86 |
+
|
87 |
+
train_loss += loss.item()
|
88 |
+
|
89 |
+
optimizer.zero_grad()
|
90 |
+
loss.backward()
|
91 |
+
optimizer.step()
|
92 |
+
|
93 |
+
loop.set_postfix(loss=loss.item()/len(dataLoader))
|
94 |
+
print('total train loss : {:.4f}\n'.format(train_loss/len(dataLoader.dataset)))
|
95 |
+
return model, train_loss/len(dataLoader.dataset)
|
96 |
+
|
97 |
+
|
98 |
+
def test(dataLoader, model, loss_fn, epoch):
|
99 |
+
print(f'~~~ validation for epoch {epoch} ~~~')
|
100 |
+
model.eval()
|
101 |
+
size = len(dataLoader)
|
102 |
+
loop = tqdm(dataLoader)
|
103 |
+
test_loss = 0
|
104 |
+
Acc = 0
|
105 |
+
Dsc1, Dsc2, Dsc3 = 0, 0, 0
|
106 |
+
IOU1, IOU2, IOU3 = 0, 0, 0
|
107 |
+
for i, (img, labels, name) in enumerate(loop):
|
108 |
+
#if torch.cuda.is_available():
|
109 |
+
labels = getTargetSegmentation(labels)
|
110 |
+
img, labels = img.to(DEVICE), labels.to(DEVICE)
|
111 |
+
|
112 |
+
yPred = model(img)
|
113 |
+
loss = loss_fn(yPred, labels)
|
114 |
+
test_loss += loss.item()
|
115 |
+
loop.set_postfix(loss=loss.item()/len(dataLoader))
|
116 |
+
|
117 |
+
Dsc = dice(yPred, labels, average='none', num_classes=4).cpu()
|
118 |
+
IOU = jaccard_index(yPred, labels, task='multiclass', average='none', num_classes=4).cpu()
|
119 |
+
Dsc1 += Dsc[1]
|
120 |
+
Dsc2 += Dsc[2]
|
121 |
+
Dsc3 += Dsc[3]
|
122 |
+
IOU1 += IOU[1]
|
123 |
+
IOU2 += IOU[2]
|
124 |
+
IOU3 += IOU[3]
|
125 |
+
print('total test loss : {:.4f}\nDice score 1 : {:.4f} | Dice score 2 : {:.4f} | Dice score 3 : {:.4f}\nIOU 1 : {:.4f} | IOU 2 : {:.4f} | IOU 3 : {:.4f}\n'.format(test_loss/size, Dsc1/size, Dsc2/size, Dsc3/size, IOU1/size, IOU2/size, IOU3/size))
|
126 |
+
return test_loss/size, Dsc1/size, Dsc2/size, Dsc3/size, IOU1/size, IOU2/size, IOU3/size
|
127 |
+
|
128 |
+
|
129 |
+
def main(train_loader, test_loader, model, optimizer, loss1, loss2):
|
130 |
+
train_loss_lst, test_loss_lst = [], []
|
131 |
+
Dsc1_lst, Dsc2_lst, Dsc3_lst = [], [], []
|
132 |
+
IOU1_lst, IOU2_lst, IOU3_lst = [], [], []
|
133 |
+
for i in range(EPOCHS):
|
134 |
+
model, train_loss = train(train_loader, model, optimizer, i+1, loss_fn1=loss1, loss_fn2=loss2)
|
135 |
+
test_loss, Dsc1, Dsc2, Dsc3, IOU1, IOU2, IOU3 = test(test_loader, model, loss1, i+1)
|
136 |
+
train_loss_lst.append(train_loss)
|
137 |
+
test_loss_lst.append(test_loss)
|
138 |
+
Dsc1_lst.append(Dsc1)
|
139 |
+
Dsc2_lst.append(Dsc2)
|
140 |
+
Dsc3_lst.append(Dsc3)
|
141 |
+
IOU1_lst.append(IOU1)
|
142 |
+
IOU2_lst.append(IOU2)
|
143 |
+
IOU3_lst.append(IOU3)
|
144 |
+
|
145 |
+
return model
|
146 |
+
|
147 |
+
|
148 |
+
if __name__=='__main__':
|
149 |
+
mdoel = main(train_loader, test_loader, model, optimizer, loss1=lossCE, loss2=lossDice)
|
150 |
+
plot_img(test_loader, 8, model)
|
src/medicalDataLoader.py
CHANGED
@@ -110,4 +110,6 @@ class MedicalImageDataset(Dataset):
|
|
110 |
img = self.transform(img)
|
111 |
mask = self.mask_transform(mask)
|
112 |
|
113 |
-
return [img, mask, img_path]
|
|
|
|
|
|
110 |
img = self.transform(img)
|
111 |
mask = self.mask_transform(mask)
|
112 |
|
113 |
+
return [img, mask, img_path]
|
114 |
+
|
115 |
+
|
src/utils.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.autograd import Variable
|
6 |
+
import torchvision
|
7 |
+
import os
|
8 |
+
import skimage.transform as skiTransf
|
9 |
+
import scipy.io as sio
|
10 |
+
import pdb
|
11 |
+
import time
|
12 |
+
import re
|
13 |
+
from os.path import isfile, join
|
14 |
+
import statistics
|
15 |
+
from PIL import Image
|
16 |
+
from medpy.metric.binary import dc, hd, asd, assd
|
17 |
+
import scipy.spatial
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
from IPython.display import Image, display
|
20 |
+
from skimage import io
|
21 |
+
import cv2
|
22 |
+
|
23 |
+
# from scipy.spatial.distance import directed_hausdorff
|
24 |
+
|
25 |
+
|
26 |
+
labels = {0: 'Background', 1: 'Foreground'}
|
27 |
+
|
28 |
+
|
29 |
+
def computeDSC(pred, gt):
|
30 |
+
dscAll = []
|
31 |
+
#pdb.set_trace()
|
32 |
+
for i_b in range(pred.shape[0]):
|
33 |
+
pred_id = pred[i_b, 1, :]
|
34 |
+
gt_id = gt[i_b, 0, :]
|
35 |
+
dscAll.append(dc(pred_id.cpu().data.numpy(), gt_id.cpu().data.numpy()))
|
36 |
+
|
37 |
+
DSC = np.asarray(dscAll)
|
38 |
+
|
39 |
+
return DSC.mean()
|
40 |
+
|
41 |
+
|
42 |
+
def getImageImageList(imagesFolder):
|
43 |
+
if os.path.exists(imagesFolder):
|
44 |
+
imageNames = [f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f))]
|
45 |
+
|
46 |
+
imageNames.sort()
|
47 |
+
|
48 |
+
return imageNames
|
49 |
+
|
50 |
+
|
51 |
+
def to_var(x):
|
52 |
+
if torch.cuda.is_available():
|
53 |
+
x = x.cuda()
|
54 |
+
return Variable(x)
|
55 |
+
|
56 |
+
|
57 |
+
def DicesToDice(Dices):
|
58 |
+
sums = Dices.sum(dim=0)
|
59 |
+
return (2 * sums[0] + 1e-8) / (sums[1] + 1e-8)
|
60 |
+
|
61 |
+
|
62 |
+
def predToSegmentation(pred):
|
63 |
+
Max = pred.max(dim=1, keepdim=True)[0]
|
64 |
+
x = pred / Max
|
65 |
+
# pdb.set_trace()
|
66 |
+
return (x == 1).float()
|
67 |
+
|
68 |
+
|
69 |
+
def getTargetSegmentation(batch):
|
70 |
+
# input is 1-channel of values between 0 and 1
|
71 |
+
# values are as follows : 0, 0.33333334, 0.6666667 and 0.94117647
|
72 |
+
# output is 1 channel of discrete values : 0, 1, 2 and 3
|
73 |
+
|
74 |
+
denom = 0.33333334 # for ACDC this value
|
75 |
+
return (batch / denom).round().long().squeeze()
|
76 |
+
|
77 |
+
|
78 |
+
from scipy import ndimage
|
79 |
+
|
80 |
+
|
81 |
+
def inference(net, img_batch, modelName, epoch):
|
82 |
+
total = len(img_batch)
|
83 |
+
net.eval()
|
84 |
+
|
85 |
+
softMax = nn.Softmax().cuda()
|
86 |
+
CE_loss = nn.CrossEntropyLoss().cuda()
|
87 |
+
|
88 |
+
losses = []
|
89 |
+
for i, data in enumerate(img_batch):
|
90 |
+
|
91 |
+
printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
|
92 |
+
images, labels, img_names = data
|
93 |
+
|
94 |
+
images = to_var(images)
|
95 |
+
labels = to_var(labels)
|
96 |
+
|
97 |
+
net_predictions = net(images)
|
98 |
+
segmentation_classes = getTargetSegmentation(labels)
|
99 |
+
CE_loss_value = CE_loss(net_predictions, segmentation_classes)
|
100 |
+
losses.append(CE_loss_value.cpu().data.numpy())
|
101 |
+
pred_y = softMax(net_predictions)
|
102 |
+
masks = torch.argmax(pred_y, dim=1)
|
103 |
+
|
104 |
+
path = os.path.join('./Results/Images/', modelName, str(epoch))
|
105 |
+
|
106 |
+
if not os.path.exists(path):
|
107 |
+
os.makedirs(path)
|
108 |
+
|
109 |
+
torchvision.utils.save_image(
|
110 |
+
torch.cat([images.data, labels.data, masks.view(labels.shape[0], 1, 256, 256).data / 3.0]),
|
111 |
+
os.path.join(path, str(i) + '.png'), padding=0)
|
112 |
+
|
113 |
+
printProgressBar(total, total, done="[Inference] Segmentation Done !")
|
114 |
+
|
115 |
+
losses = np.asarray(losses)
|
116 |
+
|
117 |
+
return losses.mean()
|
118 |
+
|
119 |
+
|
120 |
+
class MaskToTensor(object):
|
121 |
+
def __call__(self, img):
|
122 |
+
return torch.from_numpy(np.array(img, dtype=np.int32)).float()
|
123 |
+
|
124 |
+
|
125 |
+
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
|
126 |
+
print("=> Saving checkpoint")
|
127 |
+
torch.save(state, filename)
|
128 |
+
|
129 |
+
def load_checkpoint(checkpoint, model):
|
130 |
+
print("=> Loading checkpoint")
|
131 |
+
model.load_state_dict(checkpoint["state_dict"])
|
132 |
+
|
133 |
+
def check_accuracy(loader, model, device="cuda"):
|
134 |
+
num_correct = 0
|
135 |
+
num_pixels = 0
|
136 |
+
dice_score = 0
|
137 |
+
model.eval()
|
138 |
+
|
139 |
+
with torch.no_grad():
|
140 |
+
for x, y in loader:
|
141 |
+
x = x.to(device)
|
142 |
+
y = y.to(device).unsqueeze(1)
|
143 |
+
preds = torch.sigmoid(model(x))
|
144 |
+
preds = (preds > 0.5).float()
|
145 |
+
num_correct += (preds == y).sum()
|
146 |
+
num_pixels += torch.numel(preds)
|
147 |
+
dice_score += (2 * (preds * y).sum()) / (
|
148 |
+
(preds + y).sum() + 1e-8
|
149 |
+
)
|
150 |
+
|
151 |
+
print(
|
152 |
+
f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
|
153 |
+
)
|
154 |
+
print(f"Dice score: {dice_score/len(loader)}")
|
155 |
+
model.train()
|
156 |
+
|
157 |
+
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
|
158 |
+
model.eval()
|
159 |
+
for idx, (x, y) in enumerate(loader):
|
160 |
+
x = x.to(device=device)
|
161 |
+
with torch.no_grad():
|
162 |
+
preds = torch.sigmoid(model(x))
|
163 |
+
preds = (preds > 0.5).float()
|
164 |
+
torchvision.utils.save_image(
|
165 |
+
preds, f"{folder}/pred_{idx}.png"
|
166 |
+
)
|
167 |
+
torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
|
168 |
+
|
169 |
+
model.train()
|
170 |
+
|
171 |
+
|
172 |
+
# converting tensor to image
|
173 |
+
def image_convert(image):
|
174 |
+
image = image.clone().cpu().numpy()
|
175 |
+
image = image.transpose((1,2,0))
|
176 |
+
image = (image * 255)
|
177 |
+
return image
|
178 |
+
|
179 |
+
def mask_convert(mask):
|
180 |
+
mask = mask.clone().cpu().detach().numpy()
|
181 |
+
return np.squeeze(mask)
|
182 |
+
|
183 |
+
#If model is true, this will run inference on some test image and show the output on a plot
|
184 |
+
def plot_img(loader, no_, model=None):
|
185 |
+
images, target, name = next(iter(loader))
|
186 |
+
ind = np.random.choice(range(loader.batch_size))
|
187 |
+
|
188 |
+
data= to_var(images)
|
189 |
+
|
190 |
+
for idx in range(0,no_):
|
191 |
+
plt.figure(figsize=(12,12))
|
192 |
+
|
193 |
+
#Images
|
194 |
+
image = image_convert(images[idx])
|
195 |
+
plt.subplot(1,3,1)
|
196 |
+
plt.imshow(image)
|
197 |
+
plt.title('Original Image')
|
198 |
+
|
199 |
+
#Ground truth target mask
|
200 |
+
mask = mask_convert(target[idx])
|
201 |
+
plt.subplot(1,3,2)
|
202 |
+
plt.imshow(mask)
|
203 |
+
plt.title('Original Mask')
|
204 |
+
|
205 |
+
if model is None:
|
206 |
+
#superposition with target mask
|
207 |
+
plt.subplot(1,3,3)
|
208 |
+
plt.imshow(image)
|
209 |
+
plt.imshow(mask,alpha=0.6)
|
210 |
+
plt.title('Superposition')
|
211 |
+
else:
|
212 |
+
softMax = nn.Softmax().cuda()
|
213 |
+
#showing prediction mask
|
214 |
+
plt.subplot(1,3,3)
|
215 |
+
#make a prediction bases on the previous image
|
216 |
+
yhat = model(data)
|
217 |
+
pred_y = softMax(yhat)
|
218 |
+
masks = torch.argmax(pred_y, dim=1)
|
219 |
+
plt.imshow(mask_convert(masks[idx]))
|
220 |
+
plt.title('Prediction')
|
221 |
+
plt.show()
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
"""
|
227 |
+
def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False):
|
228 |
+
train_transform = A.Compose(
|
229 |
+
[
|
230 |
+
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
|
231 |
+
A.Rotate(limit=35, p=1.0),
|
232 |
+
A.HorizontalFlip(p=0.5),
|
233 |
+
A.VerticalFlip(p=0.1),
|
234 |
+
A.Normalize(
|
235 |
+
mean=[0.0, 0.0, 0.0],
|
236 |
+
std=[1.0, 1.0, 1.0],
|
237 |
+
max_pixel_value=255.0,
|
238 |
+
),
|
239 |
+
ToTensorV2(),
|
240 |
+
],
|
241 |
+
)
|
242 |
+
|
243 |
+
val_transform = A.Compose(
|
244 |
+
[
|
245 |
+
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
|
246 |
+
A.Normalize(
|
247 |
+
mean=[0.0, 0.0, 0.0],
|
248 |
+
std=[1.0, 1.0, 1.0],
|
249 |
+
max_pixel_value=255.0,
|
250 |
+
),
|
251 |
+
ToTensorV2(),
|
252 |
+
],
|
253 |
+
)
|
254 |
+
|
255 |
+
## DUE TO THE CUSTOM LOADING CLASS, HE NEED TO USE TO STEP TO LOAD DATA
|
256 |
+
train_set_full = medicalDataLoader.MedicalImageDataset('train',
|
257 |
+
root_dir,
|
258 |
+
transform=train_transform,
|
259 |
+
mask_transform=train_transform,
|
260 |
+
augment=False,
|
261 |
+
equalize=False)
|
262 |
+
|
263 |
+
train_loader_full = DataLoader(train_set_full,
|
264 |
+
batch_size=batch_size,
|
265 |
+
worker_init_fn=np.random.seed(0),
|
266 |
+
num_workers= 0,
|
267 |
+
shuffle=True)
|
268 |
+
|
269 |
+
val_set = medicalDataLoader.MedicalImageDataset('val',
|
270 |
+
root_dir,
|
271 |
+
transform=val_transform,
|
272 |
+
mask_transform=val_transform,
|
273 |
+
equalize=False)
|
274 |
+
|
275 |
+
val_loader = DataLoader(val_set,
|
276 |
+
batch_size=batch_size,
|
277 |
+
worker_init_fn=np.random.seed(0),
|
278 |
+
num_workers = 0,
|
279 |
+
shuffle=False)
|
280 |
+
|
281 |
+
if test:
|
282 |
+
test_set = medicalDataLoader.MedicalImageDataset('test',
|
283 |
+
root_dir,
|
284 |
+
transform=None,
|
285 |
+
mask_transform=None,
|
286 |
+
equalize=False)
|
287 |
+
|
288 |
+
test_loader = DataLoader(test_set,
|
289 |
+
batch_size=batch_size,
|
290 |
+
num_workers=0,
|
291 |
+
shuffle=False)
|
292 |
+
return test_loader
|
293 |
+
|
294 |
+
return train_loader_full, val_loader"""
|