DHEIVER commited on
Commit
3265fe3
1 Parent(s): 3dce0f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -20
app.py CHANGED
@@ -8,7 +8,6 @@ import time
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  def preprocess_image(img, model_name):
11
- # Preprocess the input image based on the selected model
12
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
13
  img = cv2.resize(img, (512, 512))
14
  img = cv2.GaussianBlur(img, (0, 0), 1.0)
@@ -23,7 +22,7 @@ def process_image(img, model):
23
  img = np.expand_dims(img, axis=0)
24
  img_tensor = torch.FloatTensor(img).unsqueeze(0).to(device)
25
  with torch.no_grad():
26
- logit = torch.softmax(model.forward(img_tensor), dim=1).detach().cpu().numpy()[0, 0]
27
  return (logit > 0.5).astype(np.uint8)
28
 
29
  def has_disease(segmented_img):
@@ -39,10 +38,12 @@ def explanation(has_disease_flag):
39
  explanation_text = "O angiograma não tem doença."
40
  return explanation_text
41
 
42
- def process_input_image(input_img, model_name, model):
43
  start_time = time.time()
44
 
45
  preprocessed_img = preprocess_image(input_img, model_name)
 
 
46
  processed_img = process_image(preprocessed_img, model)
47
  disease_flag = has_disease(processed_img)
48
  explanation_text = explanation(disease_flag)
@@ -51,35 +52,37 @@ def process_input_image(input_img, model_name, model):
51
  return elapsed_time, processed_img, disease_flag, explanation_text
52
 
53
  def main():
54
- # Load models
55
  model_paths = {
56
  'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
57
- # ... (other model paths here)
58
  }
59
- models = {model_name: load_model(model_path) for model_name, model_path in model_paths.items()}
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # Create Gradio interface
62
  gr_interface = gr.Interface(
63
  fn=process_input_image,
64
- inputs=[
65
- gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
66
- gr.inputs.Dropdown(list(model_paths.keys()), label='Modelo', default='SE-RegUNet 4GF'),
67
- ],
68
- outputs=[
69
- gr.outputs.Label(label="Tempo decorrido"),
70
- gr.outputs.Image(type="numpy", label="Imagem de Saída"),
71
- gr.outputs.Label(label="Possui Doença?"),
72
- gr.outputs.Label(label="Explicação"),
73
- ],
74
  title="Segmentação de Angiograma Coronariano",
75
- description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.",
76
  theme="default",
77
- layout="vertical",
78
- allow_flagging=False,
79
  )
80
 
81
  # Launch Gradio interface
82
  gr_interface.launch()
83
 
84
  if __name__ == "__main__":
85
- main()
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  def preprocess_image(img, model_name):
 
11
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
12
  img = cv2.resize(img, (512, 512))
13
  img = cv2.GaussianBlur(img, (0, 0), 1.0)
 
22
  img = np.expand_dims(img, axis=0)
23
  img_tensor = torch.FloatTensor(img).unsqueeze(0).to(device)
24
  with torch.no_grad():
25
+ logit = torch.softmax(model(img_tensor), dim=1).detach().cpu().numpy()[0, 0]
26
  return (logit > 0.5).astype(np.uint8)
27
 
28
  def has_disease(segmented_img):
 
38
  explanation_text = "O angiograma não tem doença."
39
  return explanation_text
40
 
41
+ def process_input_image(input_img, model_name):
42
  start_time = time.time()
43
 
44
  preprocessed_img = preprocess_image(input_img, model_name)
45
+ model_path = model_paths[model_name] # Assume model_paths is defined
46
+ model = load_model(model_path)
47
  processed_img = process_image(preprocessed_img, model)
48
  disease_flag = has_disease(processed_img)
49
  explanation_text = explanation(disease_flag)
 
52
  return elapsed_time, processed_img, disease_flag, explanation_text
53
 
54
  def main():
55
+ # Define model paths
56
  model_paths = {
57
  'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
58
+ # ... other model paths
59
  }
60
+
61
+ # Define Gradio components
62
+ inputs = [
63
+ gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
64
+ gr.inputs.Dropdown(list(model_paths.keys()), label='Modelo', default='SE-RegUNet 4GF'),
65
+ ]
66
+
67
+ outputs = [
68
+ gr.outputs.Label(label="Tempo decorrido"),
69
+ gr.outputs.Image(type="numpy", label="Imagem de Saída"),
70
+ gr.outputs.Label(label="Possui Doença?"),
71
+ gr.outputs.Label(label="Explicação"),
72
+ ]
73
 
74
  # Create Gradio interface
75
  gr_interface = gr.Interface(
76
  fn=process_input_image,
77
+ inputs=inputs,
78
+ outputs=outputs,
 
 
 
 
 
 
 
 
79
  title="Segmentação de Angiograma Coronariano",
80
+ description="Esta aplicação segmenta angiogramas coronarianos...",
81
  theme="default",
 
 
82
  )
83
 
84
  # Launch Gradio interface
85
  gr_interface.launch()
86
 
87
  if __name__ == "__main__":
88
+ main()