thov commited on
Commit
cb7c3ed
1 Parent(s): 7014b1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -16
app.py CHANGED
@@ -1,28 +1,73 @@
 
1
  import gradio as gr
 
2
  import torch
 
 
 
 
3
  from UNET_perso import UNET
4
  import matplotlib.pyplot as plt
5
- from src.medicalDataLoader import *
6
-
7
 
8
- def greet(name):
9
- return "Hello " + name + "!"
10
 
11
- import torchvision
12
- from torchvision.io import read_image
13
-
14
- pic = read_image('Data/val/Img/patient001_01_1.png')
15
- print(pic.shape)
16
 
17
- model = UNET(in_channels=1, out_channels=4).to('cpu')
 
18
 
19
- filepath = 'UNET_perso'
20
- model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
21
- model.eval()
22
 
23
- plt.imshow(pic)
24
- plt.show()
 
25
 
26
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  demo.launch()
 
1
+ import glob
2
  import gradio as gr
3
+ import numpy as np
4
  import torch
5
+ from torch.utils.data import DataLoader
6
+ from torchvision import transforms
7
+ import imageio.v3 as iio
8
+
9
  from UNET_perso import UNET
10
  import matplotlib.pyplot as plt
11
+ from src.medicalDataLoader import MedicalImageDataset
 
12
 
 
 
13
 
14
+ def disp_prediction(num, alpha=0.4):
 
 
 
 
15
 
16
+ lst_patients = glob.glob('./Data/val/Img/*.png')
17
+ patient = lst_patients[num]
18
 
19
+ model = UNET(in_channels=1, out_channels=4).to('cpu')
 
 
20
 
21
+ filepath = 'UNET_perso'
22
+ model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
23
+ model.eval()
24
 
25
+ transform = transforms.Compose([transforms.ToTensor()])
26
+ val_set = MedicalImageDataset('val',
27
+ './Data',
28
+ transform=transform,
29
+ mask_transform=transform,
30
+ equalize=False)
31
+
32
+ val_loader = DataLoader(val_set,
33
+ batch_size=1,
34
+ shuffle=True)
35
 
36
+ for _, (img, label, name) in enumerate(val_loader):
37
+ if name[0] == patient:
38
+ im = img
39
+
40
+ pred = model(im)
41
+
42
+ pred = pred.detach().numpy()
43
+ pred = pred.reshape(4,256,256)
44
+ pred_mask = np.argmax(pred, axis=0)
45
+
46
+ fig = plt.figure()
47
+ im = im.detach().numpy()
48
+ im = im.reshape(256,256)
49
+ plt.imshow(alpha*pred_mask + (1-alpha)*im)
50
+ plt.axis('off')
51
+
52
+ return fig
53
+
54
+ def disp_init(num):
55
+ lst_patients = glob.glob('./Data/val/Img/*.png')
56
+ patient = lst_patients[num]
57
+ print(patient)
58
+ im = iio.imread(patient)
59
+ fig = plt.figure()
60
+ plt.imshow(im)
61
+ plt.axis('off')
62
+ return fig
63
+
64
+ with gr.Blocks() as demo:
65
+ with gr.Row() as row1:
66
+ slide = gr.Slider(min=0, max=89, value=10, step=1)
67
+ with gr.Row() as row2:
68
+ slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
69
+ slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
70
+ with gr.Row() as row3:
71
+ gr.DataFrame()
72
+
73
  demo.launch()