thov commited on
Commit
279daf2
·
1 Parent(s): 932f33d

add IOU and Dice Score

Browse files
Files changed (1) hide show
  1. app.py +23 -25
app.py CHANGED
@@ -9,13 +9,12 @@ import einops
9
  import torch
10
  from torch.utils.data import DataLoader
11
  from torchvision import transforms
12
- from torchmetrics.functional import dice
13
 
14
  from UNET_perso import UNET
15
  from src.medicalDataLoader import MedicalImageDataset
16
  from src.utils import getTargetSegmentation
17
 
18
-
19
  def disp_init(num):
20
  lst_patients = glob.glob('./Data/val/Img/*.png')
21
  patient = lst_patients[num]
@@ -25,7 +24,8 @@ def disp_init(num):
25
  plt.axis('off')
26
  return fig
27
 
28
- def disp_prediction(num, alpha=0.4):
 
29
  lst_patients = glob.glob('./Data/val/Img/*.png')
30
  patient = lst_patients[num]
31
 
@@ -43,7 +43,7 @@ def disp_prediction(num, alpha=0.4):
43
  equalize=False)
44
  val_loader = DataLoader(val_set,
45
  batch_size=1,
46
- shuffle=True)
47
 
48
  for _, (img, _, name) in enumerate(val_loader):
49
  if name[0] == patient:
@@ -60,7 +60,7 @@ def disp_prediction(num, alpha=0.4):
60
 
61
  fig = plt.figure()
62
  try:
63
- plt.imshow(alpha*pred_mask + (1-alpha)*im)
64
  except ValueError:
65
  print(type(pred_mask))
66
  plt.axis('off')
@@ -100,15 +100,14 @@ def compute_metrics(num):
100
  transform=transform,
101
  mask_transform=transform,
102
  equalize=False)
103
-
104
  val_loader = DataLoader(val_set,
105
  batch_size=1,
106
- shuffle=True)
107
 
108
  for _, (img, label, name) in enumerate(val_loader):
109
  if name[0] == patient:
110
  im = img
111
- label = label.reshape(256,256)
112
 
113
  pred = model(im)
114
 
@@ -116,30 +115,29 @@ def compute_metrics(num):
116
  pred = pred.reshape(4,256,256)
117
  pred = np.argmax(pred, axis=0)
118
  pred = torch.from_numpy(pred)
119
- pred = torch.nn.functional.one_hot(pred)
120
  pred = einops.rearrange(pred, 'h w class -> class h w')
121
 
122
- label = getTargetSegmentation(label)
123
- label = torch.nn.functional.one_hot(label, num_classes=4)
124
- label = einops.rearrange(label, 'h w class -> class h w')
125
-
126
- print(pred.shape, label.shape)
127
- #dsc = dice(pred, label, average='none', num_classes=4)
128
-
129
- print(torch.mean(pred.float()), torch.max(pred.float()))
130
- print(torch.mean(label.float()), torch.max(label.float()))
131
 
132
- #print(dsc)
133
- #dsc1 = dice_score(pred[1], label[1])
134
- #dsc2 = dice_score(pred[2], label[2])
135
- #dsc3 = dice_score(pred[3], label[3])
136
- #print(dsc1, dsc2, dsc3)
137
 
138
  df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
139
-
 
 
 
 
140
  return df
141
 
142
-
143
  with gr.Blocks() as demo:
144
  with gr.Row() as row1:
145
  slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
 
9
  import torch
10
  from torch.utils.data import DataLoader
11
  from torchvision import transforms
12
+ from torchmetrics.functional import dice_score, jaccard_index
13
 
14
  from UNET_perso import UNET
15
  from src.medicalDataLoader import MedicalImageDataset
16
  from src.utils import getTargetSegmentation
17
 
 
18
  def disp_init(num):
19
  lst_patients = glob.glob('./Data/val/Img/*.png')
20
  patient = lst_patients[num]
 
24
  plt.axis('off')
25
  return fig
26
 
27
+ def disp_prediction(num, alpha=0.45):
28
+
29
  lst_patients = glob.glob('./Data/val/Img/*.png')
30
  patient = lst_patients[num]
31
 
 
43
  equalize=False)
44
  val_loader = DataLoader(val_set,
45
  batch_size=1,
46
+ shuffle=False)
47
 
48
  for _, (img, _, name) in enumerate(val_loader):
49
  if name[0] == patient:
 
60
 
61
  fig = plt.figure()
62
  try:
63
+ plt.imshow((alpha*pred_mask + (1-alpha)*im)/2)
64
  except ValueError:
65
  print(type(pred_mask))
66
  plt.axis('off')
 
100
  transform=transform,
101
  mask_transform=transform,
102
  equalize=False)
 
103
  val_loader = DataLoader(val_set,
104
  batch_size=1,
105
+ shuffle=False)
106
 
107
  for _, (img, label, name) in enumerate(val_loader):
108
  if name[0] == patient:
109
  im = img
110
+ lab = label.reshape(256,256)
111
 
112
  pred = model(im)
113
 
 
115
  pred = pred.reshape(4,256,256)
116
  pred = np.argmax(pred, axis=0)
117
  pred = torch.from_numpy(pred)
118
+ pred = torch.nn.functional.one_hot(pred, num_classes=4)
119
  pred = einops.rearrange(pred, 'h w class -> class h w')
120
 
121
+ lab = getTargetSegmentation(lab)
122
+ lab = torch.nn.functional.one_hot(lab, num_classes=4)
123
+ lab = einops.rearrange(lab, 'h w class -> class h w')
124
+
125
+ dsc1 = dice_score(pred[1], lab[1])
126
+ dsc2 = dice_score(pred[2], lab[2])
127
+ dsc3 = dice_score(pred[3], lab[3])
 
 
128
 
129
+ iou1 = jaccard_index(pred[1], lab[1], num_classes=2)
130
+ iou2 = jaccard_index(pred[2], lab[2], num_classes=2)
131
+ iou3 = jaccard_index(pred[3], lab[3], num_classes=2)
 
 
132
 
133
  df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
134
+ df.loc[len(df)] = [round(dsc1.item(),2), round(dsc2.item(),2), round(dsc3.item(),2)]
135
+ df.loc[len(df)] = [round(iou1.item(),2), round(iou2.item(),2), round(iou3.item(),2)]
136
+ df = df.assign(metric=['Dice Score', 'IoU'])
137
+ df = df[['metric','class 1','class 2','class 3']]
138
+
139
  return df
140
 
 
141
  with gr.Blocks() as demo:
142
  with gr.Row() as row1:
143
  slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')