"""Gradio demo: sample one synthetic ECG from ``deepsynthbody/deepfake_ecg``,
derive the remaining limb leads, and show the 12-lead plot as an image."""
import io
import gradio as gr
from transformers import AutoModel
import ecg_plot
import matplotlib.pyplot as plt
from PIL import Image
import torch

# Loaded once at import time so every request reuses the same weights.
# trust_remote_code=True executes model code fetched from the Hub repository.
model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)


def _derive_12_leads(prediction):
    """Return the 12 standard leads built from the model's recorded leads.

    ``prediction`` holds one lead per row. Rows 0 and 1 are treated as leads
    I and II (the Einthoven/Goldberger relations below require that). Rows
    2 onward are appended after the limb leads unchanged (presumably V1-V6 —
    TODO confirm against the model card).

    Output row order: I, II, III, aVR, aVL, aVF, then rows 2 onward.
    """
    lead_i, lead_ii = prediction[0], prediction[1]
    derived = torch.stack(
        (
            lead_ii - lead_i,               # III = II - I
            (lead_i + lead_ii) * (-0.5),    # aVR = -(I + II) / 2
            lead_i - lead_ii * 0.5,         # aVL = I - II / 2
            lead_ii - lead_i * 0.5,         # aVF = II - I / 2
        )
    )
    return torch.cat((prediction[:2], derived, prediction[2:]), dim=0)


def predict():
    """Generate one synthetic ECG and return its 12-lead plot as a PIL image."""
    with torch.no_grad():  # inference only; no autograd graph needed
        # model(1)[0] is presumably the first of one generated samples, laid
        # out as (time, leads); .t() makes each row one lead. TODO confirm.
        # NOTE(review): the original comment said "to micro volt", but dividing
        # by 1000 would convert µV to mV — confirm the model's output units.
        prediction = model(1)[0].t() / 1000
        leads = _derive_12_leads(prediction)

    # Hand ecg_plot a plain CPU array so device/grad state can't trip matplotlib.
    ecg_plot.plot(leads.detach().cpu().numpy(), sample_rate=500, title="ECG 12")

    fig = plt.gcf()
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request would leave another open figure behind.
        plt.close(fig)
    buf.seek(0)  # rewind so the PNG is read from the start
    return Image.open(buf)


# NOTE: launching at import time is kept as in the original script; wrap in
# `if __name__ == "__main__":` if this module is ever imported elsewhere.
demo = gr.Interface(
    predict,
    inputs=None,
    outputs="image",
    title="Generating Fake ECGs",
)
demo.launch()