Spaces:
Runtime error
Runtime error
import os | |
import shutil | |
import gradio as gr | |
import numpy as np | |
import wfdb | |
import torch | |
from wfdb.plot.plot import plot_wfdb | |
from wfdb.io.record import Record, rdrecord | |
from models.CNN import CNN, MMCNN_CAT | |
from models.RNN import MMRNN | |
from utils.helper_functions import predict | |
import matplotlib | |
matplotlib.use('Agg') | |
import matplotlib.pyplot as plt | |
from transformers import AutoTokenizer, AutoModel | |
from langdetect import detect | |
# Base directory for the demo assets. It is taken from the process working
# directory, so start the app from the folder that contains `demo_data/`
# (or edit this value before running).
CWD = os.getcwd()
# Checkpoint files for the two selectable classifiers (see `infer`).
MMCNN_CAT_ckpt_path = f"{CWD}/demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{CWD}/demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"
# Hugging Face model identifiers for the English and German clinical text encoders.
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'
# Tokenizers and encoders are created once at import time and reused by
# `embed` on every request.
en_tokenizer = AutoTokenizer.from_pretrained(en_clin_bert)
en_model = AutoModel.from_pretrained(en_clin_bert)
g_tokenizer = AutoTokenizer.from_pretrained(ger_clin_bert)
g_model = AutoModel.from_pretrained(ger_clin_bert)
def preprocess(data_file_path):
    """Read one WFDB record and return its samples wrapped in a leading axis.

    Args:
        data_file_path: Record path as accepted by ``wfdb.rdsamp`` (the
            caller passes the header path with its extension removed).

    Returns:
        A NumPy array containing the record's signal array as its single
        element along the first dimension.
    """
    # rdsamp yields (signal, metadata); only the signal is needed here.
    signal, _fields = wfdb.rdsamp(data_file_path)
    return np.array([signal])
def embed(notes):
    """Encode clinical notes into one mean-pooled encoder embedding.

    The notes are routed to the English encoder when ``langdetect`` reports
    ``'en'`` and to the German encoder otherwise.

    Args:
        notes: Free-text clinical notes. ``None`` is treated as an empty
            string.

    Returns:
        A 1-D tensor: the encoder's last hidden state averaged over the
        token dimension. It is detached from autograd.
    """
    from langdetect.lang_detect_exception import LangDetectException

    notes = notes or ""
    try:
        language = detect(notes)
    except LangDetectException:
        # langdetect raises on text without usable features (e.g. an empty
        # textbox); fall through to the non-English branch instead of crashing.
        language = None

    if language == 'en':
        tokenizer, encoder = en_tokenizer, en_model
    else:
        tokenizer, encoder = g_tokenizer, g_model

    # Truncate to the encoder's position limit so long notes cannot overflow
    # the position embeddings.
    tokens = tokenizer(
        notes,
        return_tensors='pt',
        truncation=True,
        max_length=encoder.config.max_position_embeddings,
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = encoder(**tokens)
    return outputs.last_hidden_state.mean(dim=1).squeeze(0)
def plot_ecg(path):
    """Load the WFDB record at *path* and return its plot as a figure.

    Args:
        path: Record path as accepted by ``wfdb.io.record.rdrecord``.

    Returns:
        The figure object produced by ``plot_wfdb`` (``return_fig=True``).
    """
    record = rdrecord(path)
    figure = plot_wfdb(
        record=record,
        title='ECG Signal Graph',
        figsize=(12, 10),
        return_fig=True,
    )
    return figure
def infer(model, data, notes):
    """Run the selected classifier on one ECG record plus its clinical notes.

    Args:
        model: Name of the classifier to use, ``"CNN"`` or ``"RNN"``.
        data: Array-like ECG samples as produced by ``preprocess``.
        notes: Free-text clinical notes, embedded with ``embed``.

    Returns:
        A dict mapping each diagnostic class label to its sigmoid
        probability rounded to two decimals.

    Raises:
        ValueError: If ``model`` is not one of the supported names (for
            example when no model has been selected in the UI).
    """
    # Validate before doing any expensive work; previously an unknown or
    # missing selection fell through and failed later on ``.eval()``.
    if model not in ("CNN", "RNN"):
        raise ValueError(f"Unknown model {model!r}; expected 'CNN' or 'RNN'.")

    embed_notes = embed(notes).unsqueeze(0)
    data = torch.tensor(data)

    if model == "CNN":
        net = MMCNN_CAT()
        checkpoint = torch.load(MMCNN_CAT_ckpt_path, map_location="cpu")
        # The CNN takes the last two axes swapped relative to the loaded record.
        data = data.transpose(1, 2).float()
    else:
        net = MMRNN(device='cpu')
        checkpoint = torch.load(MMRNN_ckpt_path, map_location="cpu")
        data = data.float()

    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    outputs, _predicted = predict(net, data, embed_notes, device='cpu')
    probabilities = torch.sigmoid(outputs)[0]

    # Label order mirrors the index order of the model outputs.
    labels = (
        'Conduction Disturbance',
        'Hypertrophy',
        'Myocardial Infarction',
        'Normal ECG',
        'ST/T Change',
    )
    return {
        label: round(probabilities[index].item(), 2)
        for index, label in enumerate(labels)
    }
def run(model_name, header_file, data_file, notes):
    """Gradio callback: classify an uploaded WFDB record and plot its signal.

    The header and data files are copied side by side into a private
    temporary directory under their original base names, read from there,
    and removed again even if reading or plotting fails.

    Args:
        model_name: Classifier name forwarded to ``infer``.
        header_file: Uploaded ``.hea`` file, either an object exposing
            ``.name`` or a plain path string.
        data_file: Uploaded ``.dat`` file, in the same form.
        notes: Free-text clinical notes forwarded to ``infer``.

    Returns:
        A ``(probabilities, figure)`` tuple for the label and plot outputs.

    Raises:
        gr.Error: If either file is missing.
    """
    import tempfile

    if header_file is None or data_file is None:
        raise gr.Error("Please upload both a header (.hea) and a data (.dat) file.")

    # Depending on the Gradio version/configuration the upload arrives either
    # as a temp-file wrapper with `.name` or directly as a path string.
    header_path = getattr(header_file, "name", header_file)
    data_path = getattr(data_file, "name", data_file)

    hdr_basename = os.path.basename(header_path)
    data_basename = os.path.basename(data_path)
    # splitext keeps record names that themselves contain dots intact.
    record_name = os.path.splitext(hdr_basename)[0]

    # A per-request directory avoids name collisions between concurrent
    # requests and never touches the files shipped in demo_data.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_path, os.path.join(work_dir, data_basename))
        shutil.copyfile(header_path, os.path.join(work_dir, hdr_basename))
        record_path = os.path.join(work_dir, record_name)
        data = preprocess(record_path)
        ECG_graph = plot_ecg(record_path)

    output = infer(model_name, data, notes)
    return output, ECG_graph
# Gradio UI: model selector, file/notes inputs next to the class
# probabilities, the ECG plot, the predict button and a set of examples.
# NOTE(review): the nesting of the layout containers was reconstructed from a
# flattened listing — confirm against the running app.
with gr.Blocks() as demo:
    with gr.Row():
        model = gr.Radio(['CNN', 'RNN'], label="Select Model")

    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
            notes = gr.Textbox(label="Clinical Notes")
        with gr.Column(scale=1):
            output_prob = gr.Label(
                {
                    'Normal ECG': 0,
                    'Myocardial Infarction': 0,
                    'ST/T Change': 0,
                    'Conduction Disturbance': 0,
                    'Hypertrophy': 0,
                },
                show_label=False,
            )

    with gr.Row():
        ecg_graph = gr.Plot(label="ECG Signal Visualisation")

    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(
            fn=run,
            inputs=[model, header_file, data_file, notes],
            outputs=[output_prob, ecg_graph],
        )

    with gr.Row():
        # Each example is a (record id, clinical notes) pair; the header and
        # data paths are derived from the record id.
        gr.Examples(
            examples=[
                [
                    f"{CWD}/demo_data/test/{record_id}.hea",
                    f"{CWD}/demo_data/test/{record_id}.dat",
                    example_notes,
                ]
                for record_id, example_notes in (
                    ("00001_lr", "sinusrhythmus periphere niederspannung"),
                    ("00008_lr", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."),
                    ("00045_lr", "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg"),
                    ("00257_lr", "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h"),
                )
            ],
            inputs=[header_file, data_file, notes],
        )


if __name__ == "__main__":
    demo.launch()