DHEIVER commited on
Commit
f797a71
1 Parent(s): 481ba50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -35
app.py CHANGED
@@ -6,15 +6,11 @@ import torch
6
  import cv2
7
  import numpy as np
8
  from preprocess import unsharp_masking
9
- import glob
10
  import time
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- print(
15
- "torch: ", torch.__version__,
16
- )
17
-
18
  def ordenar_arquivos(img, modelo):
19
  ori = img.copy()
20
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
@@ -22,6 +18,7 @@ def ordenar_arquivos(img, modelo):
22
  img_out = preprocessamento(img, modelo)
23
  return img_out, h, w, img, ori
24
 
 
25
  def preprocessamento(img, modelo='SE-RegUNet 4GF'):
26
  img = cv2.resize(img, (512, 512))
27
  img = unsharp_masking(img).astype(np.uint8)
@@ -44,6 +41,7 @@ def preprocessamento(img, modelo='SE-RegUNet 4GF'):
44
  img_out = np.stack((image1,) * 3, axis=0)
45
  return img_out
46
 
 
47
  def processar_imagem_de_entrada(img, modelo, pipe):
48
  img = img.copy()
49
  pipe = pipe.to(device).eval()
@@ -65,7 +63,7 @@ def processar_imagem_de_entrada(img, modelo, pipe):
65
  img_out[logit, 0] = 255
66
  return spent, img_out
67
 
68
- # Load the models outside the function
69
  models = {
70
  'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
71
  'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
@@ -75,46 +73,48 @@ models = {
75
  'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
76
  }
77
 
78
- def processar_imagem_de_entrada_wrapper(img, modelo):
79
- model = models[modelo]
80
- spent, img_out = processar_imagem_de_entrada(img, modelo, model)
81
 
82
- def has_disease(img, model, min_area=500):
83
- """Checks if the angiogram has disease.
84
-
85
- Args:
86
- img: The input angiogram.
87
- model: The segmentation model.
88
- min_area: The minimum area for an anomaly to be considered a disease.
89
-
90
- Returns:
91
- True if the angiogram has disease, False otherwise.
92
- """
93
 
94
- # Segment the angiogram.
95
- mask = model.predict(img)
 
96
 
97
- # Detect anomalies in the mask.
98
- anomalies = detect_anomalies(img, mask)
99
 
100
- # Check if any of the anomalies are large enough to be considered a disease.
101
- for anomaly in anomalies:
102
- if anomaly[1] >= min_area:
103
- return True
104
 
105
- return False
 
 
 
 
106
 
107
- has_disease = has_disease(img_out)
 
 
 
 
 
 
 
108
 
109
  # Adicione a explicação à interface
110
- if has_disease:
111
- explanation = "A máquina detectou uma possível doença nos vasos sanguíneos. Ela olha para a parte destacada da imagem e calcula se há mais vasos do que o normal. Se for o caso, ela diz que há uma doença. Caso contrário, ela diz que não há doença."
112
  else:
113
  explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
114
 
115
- return spent, img_out, has_disease, explanation
116
-
117
 
 
118
  my_app = gr.Interface(
119
  fn=processar_imagem_de_entrada_wrapper,
120
  inputs=[
@@ -128,10 +128,11 @@ my_app = gr.Interface(
128
  gr.outputs.Label(label="Explicação"),
129
  ],
130
  title="Segmentação de Angiograma Coronariano",
131
- description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.",
132
  theme="default",
133
  layout="vertical",
134
  allow_flagging=False,
135
  )
136
 
 
137
  my_app.launch()
 
6
  import cv2
7
  import numpy as np
8
  from preprocess import unsharp_masking
 
9
  import time
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Função para ordenar e pré-processar a imagem de entrada
 
 
 
14
  def ordenar_arquivos(img, modelo):
15
  ori = img.copy()
16
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
 
18
  img_out = preprocessamento(img, modelo)
19
  return img_out, h, w, img, ori
20
 
21
+ # Função para pré-processar a imagem com base no modelo selecionado
22
  def preprocessamento(img, modelo='SE-RegUNet 4GF'):
23
  img = cv2.resize(img, (512, 512))
24
  img = unsharp_masking(img).astype(np.uint8)
 
41
  img_out = np.stack((image1,) * 3, axis=0)
42
  return img_out
43
 
44
+ # Função para processar a imagem de entrada
45
  def processar_imagem_de_entrada(img, modelo, pipe):
46
  img = img.copy()
47
  pipe = pipe.to(device).eval()
 
63
  img_out[logit, 0] = 255
64
  return spent, img_out
65
 
66
+ # Carregar modelos pré-treinados
67
  models = {
68
  'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
69
  'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
 
73
  'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
74
  }
75
 
76
+ from sklearn.cluster import KMeans
 
 
77
 
78
+ def has_disease(img, min_area=500):
79
+ # Aplicar K-Means na imagem segmentada ou máscara
80
+ # Certifique-se de fornecer a imagem ou máscara correta
81
+ # A imagem ou máscara deve ser preprocessada e segmentada antes desta etapa
 
 
 
 
 
 
 
82
 
83
+ kmeans = KMeans(n_clusters=2, random_state=0)
84
+ flattened_img = img.reshape((-1, 1)) # Transformar a imagem em uma matriz unidimensional
85
+ kmeans.fit(flattened_img)
86
 
87
+ # Rotular as regiões da imagem
88
+ labels = kmeans.labels_
89
 
90
+ # Calcular a área das regiões
91
+ area_0 = np.sum(labels == 0)
92
+ area_1 = np.sum(labels == 1)
 
93
 
94
+ # Verificar se a área da região com doença é maior que o limite mínimo
95
+ if area_1 >= min_area:
96
+ return True
97
+ else:
98
+ return False
99
 
100
+
101
+ # Função que encapsula o processamento e a verificação de doenças
102
+ def processar_imagem_de_entrada_wrapper(img, modelo):
103
+ model = models[modelo]
104
+ spent, img_out = processar_imagem_de_entrada(img, modelo, model)
105
+
106
+ # Chame a função has_disease para verificar se há doença
107
+ has_disease_flag = has_disease(img_out)
108
 
109
  # Adicione a explicação à interface
110
+ if has_disease_flag:
111
+ explanation = "A máquina detectou uma possível doença nos vasos sanguíneos."
112
  else:
113
  explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
114
 
115
+ return spent, img_out, has_disease_flag, explanation
 
116
 
117
+ # Criar a interface Gradio
118
  my_app = gr.Interface(
119
  fn=processar_imagem_de_entrada_wrapper,
120
  inputs=[
 
128
  gr.outputs.Label(label="Explicação"),
129
  ],
130
  title="Segmentação de Angiograma Coronariano",
131
+ description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados.",
132
  theme="default",
133
  layout="vertical",
134
  allow_flagging=False,
135
  )
136
 
137
+ # Iniciar a interface Gradio
138
  my_app.launch()