# Provenance (from the hosting page): uploaded by Tej3, commit d0f8eba "Updating app file", 5.18 kB.
import os
import shutil
import gradio as gr
import numpy as np
import wfdb
import torch
from wfdb.plot.plot import plot_wfdb
from wfdb.io.record import Record, rdrecord
from models.CNN import CNN, MMCNN_CAT
from models.RNN import MMRNN
from utils.helper_functions import predict
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from langdetect import detect
# Base directory for checkpoints and demo data. It is the current working
# directory of the process, so either launch the app from the directory that
# contains `demo_data/` or edit this before running.
CWD = os.getcwd()
# Checkpoint paths for the two selectable classifiers (see infer()).
MMCNN_CAT_ckpt_path = f"{CWD}/demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{CWD}/demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"
# Hugging Face model identifiers of the clinical-text encoders:
# one for English notes and one for German notes.
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'
# Tokenizers and encoders are loaded once at import time and reused by embed().
en_tokenizer = AutoTokenizer.from_pretrained(en_clin_bert)
en_model = AutoModel.from_pretrained(en_clin_bert)
g_tokenizer = AutoTokenizer.from_pretrained(ger_clin_bert)
g_model = AutoModel.from_pretrained(ger_clin_bert)
def preprocess(data_file_path):
    """Read a WFDB record and return its samples with a leading axis of size 1.

    Args:
        data_file_path: Record path handed unchanged to ``wfdb.rdsamp``.

    Returns:
        A NumPy array holding the single record's signal array as its only
        element along the first axis.
    """
    # rdsamp yields a (signal, metadata) pair; only the signal is kept.
    signal, _metadata = wfdb.rdsamp(data_file_path)
    return np.array([signal])
def embed(notes):
    """Encode clinical notes into a single mean-pooled embedding vector.

    The language of ``notes`` is detected with ``langdetect``; English text is
    encoded with the English clinical encoder, everything else with the
    German one.

    Args:
        notes: Free-text clinical notes. ``None`` is treated as an empty
            string.

    Returns:
        A 1-D tensor: the mean over the token axis of the encoder's last
        hidden state. It is computed without gradient tracking.
    """
    # Local import keeps this block self-contained; langdetect is already a
    # dependency of this file.
    from langdetect.lang_detect_exception import LangDetectException

    notes = notes or ""
    try:
        is_english = detect(notes) == 'en'
    except LangDetectException:
        # Raised when the text has no usable features (empty, digits only,
        # ...). Fall back to the branch already used for every non-English
        # input instead of crashing the request.
        is_english = False

    if is_english:
        tokenizer, encoder = en_tokenizer, en_model
    else:
        tokenizer, encoder = g_tokenizer, g_model

    # Truncate explicitly: BERT-style encoders accept at most 512 positions,
    # and longer notes would otherwise fail inside the forward pass.
    tokens = tokenizer(notes, return_tensors='pt', truncation=True, max_length=512)
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        outputs = encoder(**tokens)
    embeddings = outputs.last_hidden_state
    return torch.mean(embeddings, dim=1).squeeze(0)
def plot_ecg(path):
    """Load the WFDB record at ``path`` and return its plot as a figure.

    Args:
        path: Record path handed unchanged to ``rdrecord``.

    Returns:
        The figure produced by ``plot_wfdb`` (requested via ``return_fig``).
    """
    record = rdrecord(path)
    figure = plot_wfdb(
        record=record,
        title='ECG Signal Graph',
        figsize=(12, 10),
        return_fig=True,
    )
    return figure
def infer(model, data, notes):
    """Classify an ECG record together with its clinical notes.

    Args:
        model: Name of the classifier to use, either ``"CNN"`` or ``"RNN"``.
        data: Signal array as returned by ``preprocess``.
        notes: Free-text clinical notes, embedded with ``embed``.

    Returns:
        A dict mapping each of the five diagnostic class names to its
        sigmoid probability rounded to two decimals.

    Raises:
        ValueError: If ``model`` is not one of the supported names (for
            example when no model was selected in the UI).
    """
    # Validate first so that a missing selection fails fast with a clear
    # message, before any embedding or checkpoint loading is done.
    if model not in ("CNN", "RNN"):
        raise ValueError(
            f"Unknown model {model!r}; expected 'CNN' or 'RNN'."
        )

    embed_notes = embed(notes).unsqueeze(0)
    data = torch.tensor(data)

    if model == "CNN":
        net = MMCNN_CAT()
        checkpoint = torch.load(MMCNN_CAT_ckpt_path, map_location="cpu")
        # The CNN branch swaps axes 1 and 2 of the signal before use.
        data = data.transpose(1, 2).float()
    else:
        net = MMRNN(device='cpu')
        checkpoint = torch.load(MMRNN_ckpt_path, map_location="cpu")
        data = data.float()

    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    outputs, _predicted = predict(net, data, embed_notes, device='cpu')
    probabilities = torch.sigmoid(outputs)[0]

    # Class names in the order of the model's output positions 0..4.
    class_names = (
        'Conduction Disturbance',
        'Hypertrophy',
        'Myocardial Infarction',
        'Normal ECG',
        'ST/T Change',
    )
    return {
        name: round(probabilities[index].item(), 2)
        for index, name in enumerate(class_names)
    }
def _uploaded_path(upload):
    """Return the filesystem path of an upload given as a file object or a path string."""
    return getattr(upload, "name", upload)


def run(model_name, header_file, data_file, notes):
    """Handle one prediction request from the UI.

    The uploaded header and data files are copied side by side into a
    private temporary directory, read and plotted from there, and the
    directory is removed again even if reading fails.

    Args:
        model_name: Classifier name forwarded to ``infer``.
        header_file: Uploaded ``.hea`` file (object with ``.name`` or a path).
        data_file: Uploaded ``.dat`` file (object with ``.name`` or a path).
        notes: Free-text clinical notes forwarded to ``infer``.

    Returns:
        A tuple ``(output, ECG_graph)``: the class-probability dict from
        ``infer`` and the figure from ``plot_ecg``.

    Raises:
        ValueError: If either upload is missing.
    """
    import tempfile

    if header_file is None or data_file is None:
        raise ValueError(
            "Both a header (.hea) file and a data (.dat) file are required."
        )

    header_path = _uploaded_path(header_file)
    data_path = _uploaded_path(data_file)
    hdr_basename = os.path.basename(header_path)
    data_basename = os.path.basename(data_path)
    # splitext drops only the final extension, so record names that contain
    # dots are kept intact.
    record_name = os.path.splitext(hdr_basename)[0]

    # A per-request directory avoids name clashes between concurrent
    # requests and keeps user-named files out of the bundled demo_data
    # directory; the context manager guarantees cleanup.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_path, os.path.join(work_dir, data_basename))
        shutil.copyfile(header_path, os.path.join(work_dir, hdr_basename))
        record_path = os.path.join(work_dir, record_name)
        data = preprocess(record_path)
        ECG_graph = plot_ecg(record_path)

    output = infer(model_name, data, notes)
    return output, ECG_graph
# UI layout: model selector, inputs next to the probability label, the ECG
# plot, the predict button, and a row of ready-made examples.
with gr.Blocks() as demo:
    with gr.Row():
        model = gr.Radio(['CNN', 'RNN'], label="Select Model")

    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
            notes = gr.Textbox(label="Clinical Notes")
        with gr.Column(scale=1):
            # Start with every class at zero until a prediction is made.
            output_prob = gr.Label(
                dict.fromkeys(
                    [
                        'Normal ECG',
                        'Myocardial Infarction',
                        'ST/T Change',
                        'Conduction Disturbance',
                        'Hypertrophy',
                    ],
                    0,
                ),
                show_label=False,
            )

    with gr.Row():
        ecg_graph = gr.Plot(label="ECG Signal Visualisation")

    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(
            fn=run,
            inputs=[model, header_file, data_file, notes],
            outputs=[output_prob, ecg_graph],
        )

    with gr.Row():
        # Each example is a header/data file pair from demo_data/test plus
        # the report text that goes into the notes box.
        gr.Examples(
            examples=[
                [
                    f"{CWD}/demo_data/test/{record_id}.hea",
                    f"{CWD}/demo_data/test/{record_id}.dat",
                    report,
                ]
                for record_id, report in (
                    ("00001_lr", "sinusrhythmus periphere niederspannung"),
                    ("00008_lr", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."),
                    ("00045_lr", "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg"),
                    ("00257_lr", "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h"),
                )
            ],
            inputs=[header_file, data_file, notes],
        )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()