|
import io

import ecg_plot
import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoModel
|
|
|
# Load the pretrained ECG generator once at import time so every request
# reuses the same instance instead of reloading weights per call.
# NOTE: trust_remote_code=True executes Python code shipped inside the Hub
# repository; only use it with a repository you trust (consider pinning
# a `revision=` for reproducibility).
model = AutoModel.from_pretrained(
    "deepsynthbody/deepfake_ecg",
    trust_remote_code=True,
)
|
|
|
def _to_twelve_leads(prediction):
    """Derive the dependent limb leads and return all 12 leads in display order.

    Rows 0 and 1 of ``prediction`` are treated as leads I and II; the remaining
    rows are assumed to be the precordial leads V1..V6 (presumably 8 rows in
    total -- TODO confirm against the model's output spec).

    The derived leads follow the standard relations:
    III = II - I, aVR = -(I + II)/2, aVL = I - II/2, aVF = II - I/2.

    Returns:
        Tensor whose rows are ordered I, II, III, aVR, aVL, aVF, then the
        input's rows 2..end (V1..V6 under the assumption above).
    """
    lead_i, lead_ii = prediction[0], prediction[1]
    derived = torch.stack(
        (
            lead_ii - lead_i,           # III
            -0.5 * (lead_i + lead_ii),  # aVR
            lead_i - 0.5 * lead_ii,     # aVL
            lead_ii - 0.5 * lead_i,     # aVF
        )
    )
    # Limb leads first (I, II, III, aVR, aVL, aVF), then the chest leads.
    return torch.cat((prediction[:2], derived, prediction[2:]), dim=0)


def _current_figure_to_image():
    """Render the active pyplot figure to an in-memory PNG, close it, and open it with PIL."""
    fig = plt.gcf()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Each request draws a fresh figure; pyplot keeps figures alive until they
    # are closed explicitly, so close it to avoid unbounded memory growth.
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict(num_ecgs=1):
    """Generate one synthetic 12-lead ECG and return it as a rendered image.

    Args:
        num_ecgs: Currently ignored -- exactly one ECG is generated and plotted.
            It defaults to 1 so the function can be called with no arguments,
            which is how ``gr.Interface(..., inputs=None)`` invokes it.

    Returns:
        A ``PIL.Image.Image`` containing the plotted 12-lead ECG.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # .t() swaps the two axes (presumably (samples, leads) -> (leads, samples));
        # /1000 rescales amplitudes (likely uV -> mV) -- TODO confirm units.
        prediction = model(1)[0].t() / 1000

    twelve_leads = _to_twelve_leads(prediction)

    # matplotlib (used by ecg_plot) needs a host-side array without grad history.
    ecg_plot.plot(
        twelve_leads.detach().cpu().numpy(),
        sample_rate=500,
        title="ECG 12",
    )
    return _current_figure_to_image()
|
|
|
# Web UI: no input widgets; each submission calls `predict` and shows the
# returned image.
demo = gr.Interface(
    fn=predict,
    inputs=None,
    outputs="image",
    title="Generating ECGs",
)
demo.launch()
|
|