File size: 4,720 Bytes
f46cff5
 
 
 
 
 
 
b31edfa
3eb5cc3
 
6d85093
f46cff5
 
3eb5cc3
 
 
 
 
 
 
f46cff5
 
3eb5cc3
f46cff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eb5cc3
 
 
f46cff5
3eb5cc3
f46cff5
 
 
 
 
 
 
 
 
 
 
 
3eb5cc3
 
 
f46cff5
3eb5cc3
 
 
 
 
 
 
 
 
 
 
 
d37da5c
2d3b13e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37da5c
5a4384b
0b2c2fa
 
 
 
 
 
 
5a4384b
 
eb3ca77
f46cff5
3eb5cc3
f46cff5
3eb5cc3
 
14ef73d
 
3eb5cc3
 
d37da5c
14ef73d
 
3eb5cc3
14ef73d
3eb5cc3
 
14ef73d
 
3eb5cc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import glob
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

print(
    "torch: ", torch.__version__,
)

def ordenar_arquivos(img, modelo):
    ori = img.copy()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    h, w = img.shape
    img_out = preprocessamento(img, modelo)
    return img_out, h, w, img, ori

def preprocessamento(img, modelo='SE-RegUNet 4GF'):
    img = cv2.resize(img, (512, 512))
    img = unsharp_masking(img).astype(np.uint8)
    if modelo == 'AngioNet' or modelo == 'UNet3+':
        img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
        img_out = np.expand_dims(img, axis=0)
    elif modelo == 'SE-RegUNet 4GF':
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image2 = clahe2.apply(img)
        img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
        image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
        image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6))
        img_out = np.stack((img, image1, image2), axis=0)
    else:
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
        img_out = np.stack((image1,) * 3, axis=0)
    return img_out

def processar_imagem_de_entrada(img, modelo, pipe):
    img = img.copy()
    pipe = pipe.to(device).eval()
    start = time.time()
    img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo)
    img = torch.FloatTensor(img).unsqueeze(0).to(device)
    with torch.no_grad():
        if modelo == 'AngioNet':
            img = torch.cat([img, img], dim=0)
        logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
    spent = time.time() - start
    spent = f"{spent:.3f} segundos"

    if h != 512 or w != 512:
        logit = cv2.resize(logit, (h, w))

    logit = logit.astype(bool)
    img_out = ori.copy()
    img_out[logit, 0] = 255
    return spent, img_out

# Load the models outside the function
models = {
    'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
    'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
    'AngioNet': torch.jit.load('./model/AngioNet.pt'),
    'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
    'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
    'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
}

def processar_imagem_de_entrada_wrapper(img, modelo):
    model = models[modelo]
    spent, img_out = processar_imagem_de_entrada(img, modelo, model)

    # Define the function `has_disease`
    def has_disease(img_out):
        """
        Checks if the angiogram has disease based on the segmentation.
        Args:
            img_out: The segmented angiogram.
        Returns:
            True if the angiogram has disease, False otherwise.
        """
        percentage_of_vessels = np.sum(img_out) / (img_out.shape[0] * img_out.shape[1])
        if percentage_of_vessels > 0.5:
            return True
        else:
            return False

    has_disease = has_disease(img_out)
    loc_doenca = np.where(img_out == 1)

    if has_disease:
        print("A doença está localizada nas seguintes coordenadas:")
        print(loc_doenca)
    else:
        print("Não há doença no angiograma.")

    return spent, img_out, has_disease, loc_doenca
    

my_app = gr.Interface(
    fn=processar_imagem_de_entrada_wrapper,
    inputs=[
        gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
        gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
    ],
    outputs=[
        gr.outputs.Label(label="Tempo decorrido"),
        gr.outputs.Image(type="numpy", label="Imagem de Saída"),
        gr.outputs.Label(label="Possui Doença?"),
    ],
    title="Segmentação de Angiograma Coronariano",
    description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados. Faça o upload de uma imagem de angiograma e selecione um modelo para visualizar o resultado da segmentação.\n\nSelecione uma imagem de angiograma coronariano e um modelo de segmentação no painel à esquerda.",
    theme="default",
    layout="vertical",
    allow_flagging=False,
)

my_app.launch()