Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ import time
|
|
8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
|
10 |
def preprocess_image(img, model_name):
|
11 |
-
# Preprocess the input image based on the selected model
|
12 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
13 |
img = cv2.resize(img, (512, 512))
|
14 |
img = cv2.GaussianBlur(img, (0, 0), 1.0)
|
@@ -23,7 +22,7 @@ def process_image(img, model):
|
|
23 |
img = np.expand_dims(img, axis=0)
|
24 |
img_tensor = torch.FloatTensor(img).unsqueeze(0).to(device)
|
25 |
with torch.no_grad():
|
26 |
-
logit = torch.softmax(model
|
27 |
return (logit > 0.5).astype(np.uint8)
|
28 |
|
29 |
def has_disease(segmented_img):
|
@@ -39,10 +38,12 @@ def explanation(has_disease_flag):
|
|
39 |
explanation_text = "O angiograma não tem doença."
|
40 |
return explanation_text
|
41 |
|
42 |
-
def process_input_image(input_img, model_name
|
43 |
start_time = time.time()
|
44 |
|
45 |
preprocessed_img = preprocess_image(input_img, model_name)
|
|
|
|
|
46 |
processed_img = process_image(preprocessed_img, model)
|
47 |
disease_flag = has_disease(processed_img)
|
48 |
explanation_text = explanation(disease_flag)
|
@@ -51,35 +52,37 @@ def process_input_image(input_img, model_name, model):
|
|
51 |
return elapsed_time, processed_img, disease_flag, explanation_text
|
52 |
|
53 |
def main():
|
54 |
-
#
|
55 |
model_paths = {
|
56 |
'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
|
57 |
-
# ...
|
58 |
}
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
# Create Gradio interface
|
62 |
gr_interface = gr.Interface(
|
63 |
fn=process_input_image,
|
64 |
-
inputs=
|
65 |
-
|
66 |
-
gr.inputs.Dropdown(list(model_paths.keys()), label='Modelo', default='SE-RegUNet 4GF'),
|
67 |
-
],
|
68 |
-
outputs=[
|
69 |
-
gr.outputs.Label(label="Tempo decorrido"),
|
70 |
-
gr.outputs.Image(type="numpy", label="Imagem de Saída"),
|
71 |
-
gr.outputs.Label(label="Possui Doença?"),
|
72 |
-
gr.outputs.Label(label="Explicação"),
|
73 |
-
],
|
74 |
title="Segmentação de Angiograma Coronariano",
|
75 |
-
description="Esta aplicação segmenta angiogramas coronarianos
|
76 |
theme="default",
|
77 |
-
layout="vertical",
|
78 |
-
allow_flagging=False,
|
79 |
)
|
80 |
|
81 |
# Launch Gradio interface
|
82 |
gr_interface.launch()
|
83 |
|
84 |
if __name__ == "__main__":
|
85 |
-
main()
|
|
|
8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
|
10 |
def preprocess_image(img, model_name):
|
|
|
11 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
12 |
img = cv2.resize(img, (512, 512))
|
13 |
img = cv2.GaussianBlur(img, (0, 0), 1.0)
|
|
|
22 |
img = np.expand_dims(img, axis=0)
|
23 |
img_tensor = torch.FloatTensor(img).unsqueeze(0).to(device)
|
24 |
with torch.no_grad():
|
25 |
+
logit = torch.softmax(model(img_tensor), dim=1).detach().cpu().numpy()[0, 0]
|
26 |
return (logit > 0.5).astype(np.uint8)
|
27 |
|
28 |
def has_disease(segmented_img):
|
|
|
38 |
explanation_text = "O angiograma não tem doença."
|
39 |
return explanation_text
|
40 |
|
41 |
+
def process_input_image(input_img, model_name):
|
42 |
start_time = time.time()
|
43 |
|
44 |
preprocessed_img = preprocess_image(input_img, model_name)
|
45 |
+
model_path = model_paths[model_name] # Assume model_paths is defined
|
46 |
+
model = load_model(model_path)
|
47 |
processed_img = process_image(preprocessed_img, model)
|
48 |
disease_flag = has_disease(processed_img)
|
49 |
explanation_text = explanation(disease_flag)
|
|
|
52 |
return elapsed_time, processed_img, disease_flag, explanation_text
|
53 |
|
54 |
def main():
|
55 |
+
# Define model paths
|
56 |
model_paths = {
|
57 |
'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
|
58 |
+
# ... other model paths
|
59 |
}
|
60 |
+
|
61 |
+
# Define Gradio components
|
62 |
+
inputs = [
|
63 |
+
gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
|
64 |
+
gr.inputs.Dropdown(list(model_paths.keys()), label='Modelo', default='SE-RegUNet 4GF'),
|
65 |
+
]
|
66 |
+
|
67 |
+
outputs = [
|
68 |
+
gr.outputs.Label(label="Tempo decorrido"),
|
69 |
+
gr.outputs.Image(type="numpy", label="Imagem de Saída"),
|
70 |
+
gr.outputs.Label(label="Possui Doença?"),
|
71 |
+
gr.outputs.Label(label="Explicação"),
|
72 |
+
]
|
73 |
|
74 |
# Create Gradio interface
|
75 |
gr_interface = gr.Interface(
|
76 |
fn=process_input_image,
|
77 |
+
inputs=inputs,
|
78 |
+
outputs=outputs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
title="Segmentação de Angiograma Coronariano",
|
80 |
+
description="Esta aplicação segmenta angiogramas coronarianos...",
|
81 |
theme="default",
|
|
|
|
|
82 |
)
|
83 |
|
84 |
# Launch Gradio interface
|
85 |
gr_interface.launch()
|
86 |
|
87 |
if __name__ == "__main__":
|
88 |
+
main()
|