DHEIVER commited on
Commit
3dce0f5
1 Parent(s): 8a22cef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -114
app.py CHANGED
@@ -1,131 +1,85 @@
1
  import os
2
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3
-
4
  import gradio as gr
5
  import torch
6
  import cv2
7
  import numpy as np
8
- from preprocess import unsharp_masking
9
- import glob
10
  import time
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- print(
15
- "torch: ", torch.__version__,
16
- )
17
-
18
- def ordenar_arquivos(img, modelo):
19
- ori = img.copy()
20
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
21
- h, w = img.shape
22
- img_out = preprocessamento(img, modelo)
23
- return img_out, h, w, img, ori
24
-
25
- def preprocessamento(img, modelo='SE-RegUNet 4GF'):
26
  img = cv2.resize(img, (512, 512))
27
- img = unsharp_masking(img).astype(np.uint8)
28
- if modelo == 'AngioNet' or modelo == 'UNet3+':
29
- img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
30
- img_out = np.expand_dims(img, axis=0)
31
- elif modelo == 'SE-RegUNet 4GF':
32
- clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
33
- clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
34
- image1 = clahe1.apply(img)
35
- image2 = clahe2.apply(img)
36
- img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
37
- image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
38
- image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6))
39
- img_out = np.stack((img, image1, image2), axis=0)
40
- else:
41
- clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
42
- image1 = clahe1.apply(img)
43
- image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
44
- img_out = np.stack((image1,) * 3, axis=0)
45
- return img_out
46
-
47
- def processar_imagem_de_entrada(img, modelo, pipe):
48
- img = img.copy()
49
- pipe = pipe.to(device).eval()
50
- start = time.time()
51
- img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo)
52
- img = torch.FloatTensor(img).unsqueeze(0).to(device)
53
- with torch.no_grad():
54
- if modelo == 'AngioNet':
55
- img = torch.cat([img, img], dim=0)
56
- logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
57
- spent = time.time() - start
58
- spent = f"{spent:.3f} segundos"
59
-
60
- if h != 512 or w != 512:
61
- logit = cv2.resize(logit, (h, w))
62
-
63
- logit = logit.astype(bool)
64
- img_out = ori.copy()
65
- img_out[logit, 0] = 255
66
- return spent, img_out
67
 
68
- # Load the models outside the function
69
- models = {
70
- 'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
71
- 'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
72
- 'AngioNet': torch.jit.load('./model/AngioNet.pt'),
73
- 'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
74
- 'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
75
- 'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
76
- }
77
 
78
- def processar_imagem_de_entrada_wrapper(img, modelo):
79
- model = models[modelo]
80
- spent, img_out = processar_imagem_de_entrada(img, modelo, model)
81
-
82
- # Define the function `has_disease`
83
- def has_disease(img_out):
84
- """
85
- Checks if the angiogram has disease based on the segmentation.
86
- Args:
87
- img_out: The segmented angiogram.
88
- Returns:
89
- True if the angiogram has disease, False otherwise.
90
- """
91
- percentage_of_vessels = np.sum(img_out) / (img_out.shape[0] * img_out.shape[1])
92
- if percentage_of_vessels > 0.5:
93
- return True
94
- else:
95
- return False
96
 
97
- has_disease = has_disease(img_out)
 
 
98
 
99
- # Add the explanation to the interface
100
- if has_disease:
101
- explanation = (
102
- f"* **True:** O angiograma tem doença. Isso é determinado pela função `has_disease`, que calcula o percentual de vasos no angiograma segmentado. Se o percentual de vasos for maior que 50%, a função retorna \"true\", indicando que o angiograma tem doença. Caso contrário, a função retorna \"false\", indicando que o angiograma não tem doença."
103
  )
104
  else:
105
- explanation = (
106
- f"* **False:** O angiograma não tem doença."
107
- )
108
-
109
- return spent, img_out, has_disease, explanation
110
-
111
-
112
- my_app = gr.Interface(
113
- fn=processar_imagem_de_entrada_wrapper,
114
- inputs=[
115
- gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
116
- gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
117
- ],
118
- outputs=[
119
- gr.outputs.Label(label="Tempo decorrido"),
120
- gr.outputs.Image(type="numpy", label="Imagem de Saída"),
121
- gr.outputs.Label(label="Possui Doença?"),
122
- gr.outputs.Label(label="Explicação"),
123
- ],
124
- title="Segmentação de Angiograma Coronariano",
125
- description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.",
126
- theme="default",
127
- layout="vertical",
128
- allow_flagging=False,
129
- )
130
-
131
- my_app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
2
  import gradio as gr
3
  import torch
4
  import cv2
5
  import numpy as np
 
 
6
  import time
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
+ def preprocess_image(img, model_name):
11
+ # Preprocess the input image based on the selected model
 
 
 
 
12
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
 
 
 
 
 
13
  img = cv2.resize(img, (512, 512))
14
+ img = cv2.GaussianBlur(img, (0, 0), 1.0)
15
+ img = img.astype(np.float32) / 255.0
16
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def load_model(model_path):
19
+ model = torch.jit.load(model_path)
20
+ return model
 
 
 
 
 
 
21
 
22
+ def process_image(img, model):
23
+ img = np.expand_dims(img, axis=0)
24
+ img_tensor = torch.FloatTensor(img).unsqueeze(0).to(device)
25
+ with torch.no_grad():
26
+ logit = torch.softmax(model.forward(img_tensor), dim=1).detach().cpu().numpy()[0, 0]
27
+ return (logit > 0.5).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def has_disease(segmented_img):
30
+ percentage_of_vessels = np.sum(segmented_img) / (segmented_img.shape[0] * segmented_img.shape[1])
31
+ return percentage_of_vessels > 0.5
32
 
33
+ def explanation(has_disease_flag):
34
+ if has_disease_flag:
35
+ explanation_text = (
36
+ "O angiograma tem doença. Isso é determinado pela função `has_disease`, que calcula o percentual de vasos no angiograma segmentado. Se o percentual de vasos for maior que 50%, a função retorna \"true\", indicando que o angiograma tem doença. Caso contrário, a função retorna \"false\", indicando que o angiograma não tem doença."
37
  )
38
  else:
39
+ explanation_text = "O angiograma não tem doença."
40
+ return explanation_text
41
+
42
+ def process_input_image(input_img, model_name, model):
43
+ start_time = time.time()
44
+
45
+ preprocessed_img = preprocess_image(input_img, model_name)
46
+ processed_img = process_image(preprocessed_img, model)
47
+ disease_flag = has_disease(processed_img)
48
+ explanation_text = explanation(disease_flag)
49
+
50
+ elapsed_time = f"{time.time() - start_time:.3f} segundos"
51
+ return elapsed_time, processed_img, disease_flag, explanation_text
52
+
53
+ def main():
54
+ # Load models
55
+ model_paths = {
56
+ 'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
57
+ # ... (other model paths here)
58
+ }
59
+ models = {model_name: load_model(model_path) for model_name, model_path in model_paths.items()}
60
+
61
+ # Create Gradio interface
62
+ gr_interface = gr.Interface(
63
+ fn=process_input_image,
64
+ inputs=[
65
+ gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
66
+ gr.inputs.Dropdown(list(model_paths.keys()), label='Modelo', default='SE-RegUNet 4GF'),
67
+ ],
68
+ outputs=[
69
+ gr.outputs.Label(label="Tempo decorrido"),
70
+ gr.outputs.Image(type="numpy", label="Imagem de Saída"),
71
+ gr.outputs.Label(label="Possui Doença?"),
72
+ gr.outputs.Label(label="Explicação"),
73
+ ],
74
+ title="Segmentação de Angiograma Coronariano",
75
+ description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.",
76
+ theme="default",
77
+ layout="vertical",
78
+ allow_flagging=False,
79
+ )
80
+
81
+ # Launch Gradio interface
82
+ gr_interface.launch()
83
+
84
+ if __name__ == "__main__":
85
+ main()