File size: 4,843 Bytes
67b1ebe
eb72eac
 
67b1ebe
 
 
 
eb72eac
67b1ebe
acd7746
67b1ebe
 
 
f797a71
eb72eac
 
67b1ebe
eb72eac
 
 
67b1ebe
f797a71
eb72eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67b1ebe
f797a71
eb72eac
 
 
 
 
 
3dce0f5
eb72eac
 
 
 
 
 
 
 
 
 
 
 
 
 
f797a71
eb72eac
 
 
 
 
 
 
 
67b1ebe
acd7746
 
 
9a5fe1d
acd7746
f797a71
a042e2b
f797a71
 
 
 
8749c79
9a5fe1d
a042e2b
 
 
 
 
 
 
f797a71
 
67b1ebe
d797764
 
a042e2b
 
 
eb72eac
f797a71
eb72eac
 
 
3265fe3
eb72eac
 
 
3265fe3
 
 
 
eb72eac
 
f797a71
eb72eac
 
 
 
 
f797a71
481ba50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import time
from sklearn.cluster import KMeans

device = "cuda" if torch.cuda.is_available() else "cpu"

# Função para ordenar e pré-processar a imagem de entrada
def ordenar_arquivos(img, modelo):
    ori = img.copy()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    h, w = img.shape
    img_out = preprocessamento(img, modelo)
    return img_out, h, w, img, ori

# Função para pré-processar a imagem com base no modelo selecionado
def preprocessamento(img, modelo='SE-RegUNet 4GF'):
    img = cv2.resize(img, (512, 512))
    img = unsharp_masking(img).astype(np.uint8)
    if modelo == 'AngioNet' or modelo == 'UNet3+':
        img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
        img_out = np.expand_dims(img, axis=0)
    elif modelo == 'SE-RegUNet 4GF':
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image2 = clahe2.apply(img)
        img = np.float32((img - img.min()) / (img.max() - img.min() + 1e-6))
        image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
        image2 = np.float32((image2 - image2.min()) / (image2.max() - image2.min() + 1e-6))
        img_out = np.stack((img, image1, image2), axis=0)
    else:
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image1 = np.float32((image1 - image1.min()) / (image1.max() - image1.min() + 1e-6))
        img_out = np.stack((image1,) * 3, axis=0)
    return img_out

# Função para processar a imagem de entrada
def processar_imagem_de_entrada(img, modelo, pipe):
    img = img.copy()
    pipe = pipe.to(device).eval()
    start = time.time()
    img, h, w, ori_gray, ori = ordenar_arquivos(img, modelo)
    img = torch.FloatTensor(img).unsqueeze(0).to(device)
    with torch.no_grad():
        if modelo == 'AngioNet':
            img = torch.cat([img, img], dim=0)
        logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
    spent = time.time() - start
    spent = f"{spent:.3f} segundos"

    if h != 512 or w != 512:
        logit = cv2.resize(logit, (h, w))

    logit = logit.astype(bool)
    img_out = ori.copy()
    img_out[logit, 0] = 255
    return spent, img_out

# Carregar modelos pré-treinados
models = {
    'SE-RegUNet 4GF': torch.jit.load('./model/SERegUNet4GF.pt'),
    'SE-RegUNet 16GF': torch.jit.load('./model/SERegUNet16GF.pt'),
    'AngioNet': torch.jit.load('./model/AngioNet.pt'),
    'EffUNet++ B5': torch.jit.load('./model/EffUNetppb5.pt'),
    'Reg-SA-UNet++': torch.jit.load('./model/RegSAUnetpp.pt'),
    'UNet3+': torch.jit.load('./model/UNet3plus.pt'),
}

def processar_imagem_de_entrada_wrapper(img, modelo):
    model = models[modelo]
    spent, img_out = processar_imagem_de_entrada(img, modelo, model)
    
    # Verificar se há doença usando K-Means
    kmeans = KMeans(n_clusters=2, random_state=0)
    flattened_img = img_out[:, :, 0].reshape((-1, 1))  # Use the intensity channel
    kmeans.fit(flattened_img)
    labels = kmeans.labels_
    area_0 = np.sum(labels == 0)
    area_1 = np.sum(labels == 1)
    has_disease_flag = area_1 >= 200
    
    # Formatar o indicador de doença como uma string
    if has_disease_flag:
        status_doenca = "Sim"
    else:
        status_doenca = "Não"
    
    # Adicionar a explicação com base no status de doença
    if has_disease_flag:
        explanation = "A máquina detectou uma possível doença nos vasos sanguíneos."
    else:
        explanation = "A máquina não detectou nenhuma doença nos vasos sanguíneos."
    
    # ... (resto do seu código, se houver mais)
    
    return spent, img_out, status_doenca, explanation

# Criar a interface Gradio
my_app = gr.Interface(
    fn=processar_imagem_de_entrada_wrapper,
    inputs=[
        gr.inputs.Image(label="Angiograma:", shape=(512, 512)),
        gr.inputs.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'], label='Modelo', default='SE-RegUNet 4GF'),
    ],
    outputs=[
        gr.outputs.Label(label="Tempo decorrido"),
        gr.outputs.Image(type="numpy", label="Imagem de Saída"),
        gr.outputs.Label(label="Possui Doença?"),
        gr.outputs.Label(label="Explicação"),
    ],
    title="Segmentação de Angiograma Coronariano",
    description="Esta aplicação segmenta angiogramas coronarianos usando modelos de segmentação pré-treinados.",
    theme="default",
    layout="vertical",
    allow_flagging=False,
)

# Iniciar a interface Gradio
my_app.launch()