thov commited on
Commit
32278ee
1 Parent(s): 279daf2

implementation perso of dice and iou

Browse files
Files changed (2) hide show
  1. app.py +10 -13
  2. src/utils.py +22 -270
app.py CHANGED
@@ -9,11 +9,10 @@ import einops
9
  import torch
10
  from torch.utils.data import DataLoader
11
  from torchvision import transforms
12
- from torchmetrics.functional import dice_score, jaccard_index
13
 
14
  from UNET_perso import UNET
15
  from src.medicalDataLoader import MedicalImageDataset
16
- from src.utils import getTargetSegmentation
17
 
18
  def disp_init(num):
19
  lst_patients = glob.glob('./Data/val/Img/*.png')
@@ -62,22 +61,20 @@ def disp_prediction(num, alpha=0.45):
62
  try:
63
  plt.imshow((alpha*pred_mask + (1-alpha)*im)/2)
64
  except ValueError:
65
- print(type(pred_mask))
66
  plt.axis('off')
67
 
68
  return fig
69
 
70
- def disp_label(num, alpha=0.4):
71
  lst_GT = glob.glob('./Data/val/GT/*.png')
72
  GT = lst_GT[num]
73
-
74
  label = iio.imread(GT)
75
-
76
  fig = plt.figure()
77
- try :
78
  plt.imshow(label)
79
  except ValueError:
80
- print(type(label))
81
  plt.axis('off')
82
 
83
  return fig
@@ -126,13 +123,13 @@ def compute_metrics(num):
126
  dsc2 = dice_score(pred[2], lab[2])
127
  dsc3 = dice_score(pred[3], lab[3])
128
 
129
- iou1 = jaccard_index(pred[1], lab[1], num_classes=2)
130
- iou2 = jaccard_index(pred[2], lab[2], num_classes=2)
131
- iou3 = jaccard_index(pred[3], lab[3], num_classes=2)
132
 
133
  df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
134
- df.loc[len(df)] = [round(dsc1.item(),2), round(dsc2.item(),2), round(dsc3.item(),2)]
135
- df.loc[len(df)] = [round(iou1.item(),2), round(iou2.item(),2), round(iou3.item(),2)]
136
  df = df.assign(metric=['Dice Score', 'IoU'])
137
  df = df[['metric','class 1','class 2','class 3']]
138
 
 
9
  import torch
10
  from torch.utils.data import DataLoader
11
  from torchvision import transforms
 
12
 
13
  from UNET_perso import UNET
14
  from src.medicalDataLoader import MedicalImageDataset
15
+ from src.utils import getTargetSegmentation, IOU, dice_score
16
 
17
  def disp_init(num):
18
  lst_patients = glob.glob('./Data/val/Img/*.png')
 
61
  try:
62
  plt.imshow((alpha*pred_mask + (1-alpha)*im)/2)
63
  except ValueError:
64
+ pass
65
  plt.axis('off')
66
 
67
  return fig
68
 
69
+ def disp_label(num):
70
  lst_GT = glob.glob('./Data/val/GT/*.png')
71
  GT = lst_GT[num]
 
72
  label = iio.imread(GT)
 
73
  fig = plt.figure()
74
+ try:
75
  plt.imshow(label)
76
  except ValueError:
77
+ pass
78
  plt.axis('off')
79
 
80
  return fig
 
123
  dsc2 = dice_score(pred[2], lab[2])
124
  dsc3 = dice_score(pred[3], lab[3])
125
 
126
+ iou1 = IOU(pred[1], lab[1])
127
+ iou2 = IOU(pred[2], lab[2])
128
+ iou3 = IOU(pred[3], lab[3])
129
 
130
  df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
131
+ df.loc[len(df)] = [round(dsc1,2), round(dsc2,2), round(dsc3,2)]
132
+ df.loc[len(df)] = [round(iou1,2), round(iou2,2), round(iou3,2)]
133
  df = df.assign(metric=['Dice Score', 'IoU'])
134
  df = df[['metric','class 1','class 2','class 3']]
135
 
src/utils.py CHANGED
@@ -1,56 +1,27 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- from torch.autograd import Variable
5
- import torchvision
6
- import os
7
- from os.path import isfile, join
8
  from medpy.metric.binary import dc, hd, asd, assd
9
- import matplotlib.pyplot as plt
10
- from IPython.display import Image, display
11
 
12
- labels = {0: 'Background', 1: 'Foreground'}
13
-
14
-
15
- def computeDSC(pred, gt):
16
- dscAll = []
17
- #pdb.set_trace()
18
- for i_b in range(pred.shape[0]):
19
- pred_id = pred[i_b, 1, :]
20
- gt_id = gt[i_b, 0, :]
21
- dscAll.append(dc(pred_id.cpu().data.numpy(), gt_id.cpu().data.numpy()))
22
-
23
- DSC = np.asarray(dscAll)
24
-
25
- return DSC.mean()
26
-
27
-
28
- def getImageImageList(imagesFolder):
29
- if os.path.exists(imagesFolder):
30
- imageNames = [f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f))]
31
-
32
- imageNames.sort()
33
-
34
- return imageNames
35
-
36
-
37
- def to_var(x):
38
- if torch.cuda.is_available():
39
- x = x.cuda()
40
- return Variable(x)
41
-
42
-
43
- def DicesToDice(Dices):
44
- sums = Dices.sum(dim=0)
45
- return (2 * sums[0] + 1e-8) / (sums[1] + 1e-8)
46
-
47
-
48
- def predToSegmentation(pred):
49
- Max = pred.max(dim=1, keepdim=True)[0]
50
- x = pred / Max
51
- # pdb.set_trace()
52
- return (x == 1).float()
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def getTargetSegmentation(batch):
56
  # input is 1-channel of values between 0 and 1
@@ -58,223 +29,4 @@ def getTargetSegmentation(batch):
58
  # output is 1 channel of discrete values : 0, 1, 2 and 3
59
 
60
  denom = 0.33333334 # for ACDC this value
61
- return (batch / denom).round().long().squeeze()
62
-
63
-
64
- from scipy import ndimage
65
-
66
-
67
- def inference(net, img_batch, modelName, epoch):
68
- total = len(img_batch)
69
- net.eval()
70
-
71
- softMax = nn.Softmax().cuda()
72
- CE_loss = nn.CrossEntropyLoss().cuda()
73
-
74
- losses = []
75
- for i, data in enumerate(img_batch):
76
-
77
- printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
78
- images, labels, img_names = data
79
-
80
- images = to_var(images)
81
- labels = to_var(labels)
82
-
83
- net_predictions = net(images)
84
- segmentation_classes = getTargetSegmentation(labels)
85
- CE_loss_value = CE_loss(net_predictions, segmentation_classes)
86
- losses.append(CE_loss_value.cpu().data.numpy())
87
- pred_y = softMax(net_predictions)
88
- masks = torch.argmax(pred_y, dim=1)
89
-
90
- path = os.path.join('./Results/Images/', modelName, str(epoch))
91
-
92
- if not os.path.exists(path):
93
- os.makedirs(path)
94
-
95
- torchvision.utils.save_image(
96
- torch.cat([images.data, labels.data, masks.view(labels.shape[0], 1, 256, 256).data / 3.0]),
97
- os.path.join(path, str(i) + '.png'), padding=0)
98
-
99
- printProgressBar(total, total, done="[Inference] Segmentation Done !")
100
-
101
- losses = np.asarray(losses)
102
-
103
- return losses.mean()
104
-
105
-
106
- class MaskToTensor(object):
107
- def __call__(self, img):
108
- return torch.from_numpy(np.array(img, dtype=np.int32)).float()
109
-
110
-
111
- def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
112
- print("=> Saving checkpoint")
113
- torch.save(state, filename)
114
-
115
- def load_checkpoint(checkpoint, model):
116
- print("=> Loading checkpoint")
117
- model.load_state_dict(checkpoint["state_dict"])
118
-
119
- def check_accuracy(loader, model, device="cuda"):
120
- num_correct = 0
121
- num_pixels = 0
122
- dice_score = 0
123
- model.eval()
124
-
125
- with torch.no_grad():
126
- for x, y in loader:
127
- x = x.to(device)
128
- y = y.to(device).unsqueeze(1)
129
- preds = torch.sigmoid(model(x))
130
- preds = (preds > 0.5).float()
131
- num_correct += (preds == y).sum()
132
- num_pixels += torch.numel(preds)
133
- dice_score += (2 * (preds * y).sum()) / (
134
- (preds + y).sum() + 1e-8
135
- )
136
-
137
- print(
138
- f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
139
- )
140
- print(f"Dice score: {dice_score/len(loader)}")
141
- model.train()
142
-
143
- def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
144
- model.eval()
145
- for idx, (x, y) in enumerate(loader):
146
- x = x.to(device=device)
147
- with torch.no_grad():
148
- preds = torch.sigmoid(model(x))
149
- preds = (preds > 0.5).float()
150
- torchvision.utils.save_image(
151
- preds, f"{folder}/pred_{idx}.png"
152
- )
153
- torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
154
-
155
- model.train()
156
-
157
-
158
- # converting tensor to image
159
- def image_convert(image):
160
- image = image.clone().cpu().numpy()
161
- image = image.transpose((1,2,0))
162
- image = (image * 255)
163
- return image
164
-
165
- def mask_convert(mask):
166
- mask = mask.clone().cpu().detach().numpy()
167
- return np.squeeze(mask)
168
-
169
- #If model is true, this will run inference on some test image and show the output on a plot
170
- def plot_img(loader, no_, model=None):
171
- images, target, name = next(iter(loader))
172
- ind = np.random.choice(range(loader.batch_size))
173
-
174
- data= to_var(images)
175
-
176
- for idx in range(0,no_):
177
- plt.figure(figsize=(12,12))
178
-
179
- #Images
180
- image = image_convert(images[idx])
181
- plt.subplot(1,3,1)
182
- plt.imshow(image)
183
- plt.title('Original Image')
184
-
185
- #Ground truth target mask
186
- mask = mask_convert(target[idx])
187
- plt.subplot(1,3,2)
188
- plt.imshow(mask)
189
- plt.title('Original Mask')
190
-
191
- if model is None:
192
- #superposition with target mask
193
- plt.subplot(1,3,3)
194
- plt.imshow(image)
195
- plt.imshow(mask,alpha=0.6)
196
- plt.title('Superposition')
197
- else:
198
- softMax = nn.Softmax().cuda()
199
- #showing prediction mask
200
- plt.subplot(1,3,3)
201
- #make a prediction bases on the previous image
202
- yhat = model(data)
203
- pred_y = softMax(yhat)
204
- masks = torch.argmax(pred_y, dim=1)
205
- plt.imshow(mask_convert(masks[idx]))
206
- plt.title('Prediction')
207
- plt.show()
208
-
209
-
210
-
211
-
212
- """
213
- def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False):
214
- train_transform = A.Compose(
215
- [
216
- A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
217
- A.Rotate(limit=35, p=1.0),
218
- A.HorizontalFlip(p=0.5),
219
- A.VerticalFlip(p=0.1),
220
- A.Normalize(
221
- mean=[0.0, 0.0, 0.0],
222
- std=[1.0, 1.0, 1.0],
223
- max_pixel_value=255.0,
224
- ),
225
- ToTensorV2(),
226
- ],
227
- )
228
-
229
- val_transform = A.Compose(
230
- [
231
- A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
232
- A.Normalize(
233
- mean=[0.0, 0.0, 0.0],
234
- std=[1.0, 1.0, 1.0],
235
- max_pixel_value=255.0,
236
- ),
237
- ToTensorV2(),
238
- ],
239
- )
240
-
241
- ## DUE TO THE CUSTOM LOADING CLASS, HE NEED TO USE TO STEP TO LOAD DATA
242
- train_set_full = medicalDataLoader.MedicalImageDataset('train',
243
- root_dir,
244
- transform=train_transform,
245
- mask_transform=train_transform,
246
- augment=False,
247
- equalize=False)
248
-
249
- train_loader_full = DataLoader(train_set_full,
250
- batch_size=batch_size,
251
- worker_init_fn=np.random.seed(0),
252
- num_workers= 0,
253
- shuffle=True)
254
-
255
- val_set = medicalDataLoader.MedicalImageDataset('val',
256
- root_dir,
257
- transform=val_transform,
258
- mask_transform=val_transform,
259
- equalize=False)
260
-
261
- val_loader = DataLoader(val_set,
262
- batch_size=batch_size,
263
- worker_init_fn=np.random.seed(0),
264
- num_workers = 0,
265
- shuffle=False)
266
-
267
- if test:
268
- test_set = medicalDataLoader.MedicalImageDataset('test',
269
- root_dir,
270
- transform=None,
271
- mask_transform=None,
272
- equalize=False)
273
-
274
- test_loader = DataLoader(test_set,
275
- batch_size=batch_size,
276
- num_workers=0,
277
- shuffle=False)
278
- return test_loader
279
-
280
- return train_loader_full, val_loader"""
 
1
+ from itertools import chain
 
 
 
 
 
 
2
  from medpy.metric.binary import dc, hd, asd, assd
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ def IOU(pred, label):
6
+ sum_ = pred+label
7
+ overlap = sum([1 for _, val in enumerate(list(chain(*sum_))) if val==2])
8
+ union = sum([1 for _, val in enumerate(list(chain(*sum_))) if val==1])
9
+ try:
10
+ iou = overlap/(union+overlap)
11
+ except ZeroDivisionError:
12
+ iou = 0
13
+ return iou
14
+
15
+ def dice_score(pred, label):
16
+ sum_ = pred+label
17
+ overlap = sum([1 for _, val in enumerate(list(chain(*sum_))) if val==2])
18
+ predAera = sum([1 for _, val in enumerate(list(chain(*pred))) if val==1])
19
+ labelAera = sum([1 for _, val in enumerate(list(chain(*pred))) if val==1])
20
+ try:
21
+ ds = (2*overlap)/(predAera+labelAera)
22
+ except ZeroDivisionError:
23
+ ds = 0
24
+ return ds
25
 
26
  def getTargetSegmentation(batch):
27
  # input is 1-channel of values between 0 and 1
 
29
  # output is 1 channel of discrete values : 0, 1, 2 and 3
30
 
31
  denom = 0.33333334 # for ACDC this value
32
+ return (batch / denom).round().long().squeeze()