File size: 1,181 Bytes
58b95fb
7bbfc34
66eabd8
3c8cf6c
b5953e8
 
 
58b95fb
3c8cf6c
 
 
b8a864b
b5953e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8cf6c
 
 
b5953e8
 
f89f767
a2c89cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import io
import gradio as gr
#from transformers import pipeline
from transformers import AutoModel
import ecg_plot
import matplotlib.pyplot as plt
from PIL import Image
import torch
#pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
# Load the ECG generator once at import time; predict() below reuses this
# module-level instance on every call.
# NOTE(review): trust_remote_code=True allows transformers to execute Python
# code shipped in the "deepsynthbody/deepfake_ecg" Hub repository -- confirm
# that repository is trusted (pinning a revision would be safer).
model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)

def predict():
    """Generate one synthetic ECG and return it rendered as a 12-lead plot.

    The generator output is transposed so that each row is one lead. Rows 0
    and 1 are used as limb leads I and II, from which leads III, aVR, aVL and
    aVF are derived. The twelve rows are then reordered to
    I, II, III, aVR, aVL, aVF followed by the six remaining generated rows
    (presumably V1-V6 -- TODO confirm against the model card) and plotted.

    Returns:
        PIL.Image.Image: the plot, decoded from an in-memory PNG.
    """
    # Inference only: no autograd graph is needed for sampling.
    with torch.no_grad():
        # model(1): presumably "generate one ECG" -- TODO confirm.
        # The original author's note on the /1000 scaling was "to microvolts";
        # NOTE(review): confirm source/target units against the model card.
        prediction = model(1)[0].t() / 1000

    lead_i, lead_ii = prediction[0], prediction[1]
    derived = torch.stack(
        (
            lead_ii - lead_i,            # III
            (lead_i + lead_ii) * (-0.5),  # aVR
            lead_i - lead_ii * 0.5,       # aVL
            lead_ii - lead_i * 0.5,       # aVF
        ),
        dim=0,
    )

    # Derived leads are appended at rows 8..11, then moved right after I and II.
    leads = torch.cat((prediction, derived), dim=0)
    display_order = torch.tensor([0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7])
    ordered = leads[display_order]

    ecg_plot.plot(ordered, sample_rate=500, title='ECG 12')

    # Render the current figure to PNG in memory, and always close it:
    # pyplot keeps every figure alive until closed, which would otherwise
    # leak one figure per request in a long-running server.
    fig = plt.gcf()
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    # Rewind explicitly so decoding starts at the first byte of the PNG.
    buf.seek(0)
    return Image.open(buf)

# Web UI: no input widgets, a single image output produced by predict().
demo = gr.Interface(
    fn=predict,
    inputs=None,
    outputs="image",
    title="Generating Fake ECGs",
)
demo.launch()