vlbthambawita commited on
Commit
b5953e8
1 Parent(s): 66eabd8
Files changed (2) hide show
  1. app.py +23 -6
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,17 +1,34 @@
1
  import gradio as gr
2
  #from transformers import pipeline
3
  from transformers import AutoModel
4
-
 
 
5
  #pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
6
  model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)
7
 
8
  def predict(num_ecgs):
9
- predictions = model(int(num_ecgs))
10
- return {"ecgs": predictions}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  gr.Interface(
13
  predict,
14
- inputs="text",
15
- outputs="text",
16
  title="Generating ECGs",
17
- ).launch()
 
1
  import gradio as gr
2
  #from transformers import pipeline
3
  from transformers import AutoModel
4
+ import ecg_plot
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
  #pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
8
  model = AutoModel.from_pretrained("deepsynthbody/deepfake_ecg", trust_remote_code=True)
9
 
10
  def predict(num_ecgs):
11
+ prediction = (model(1)[0].t()/1000) # to micro volte
12
+
13
+
14
+ lead_III = (prediction[1] - prediction[0]).unsqueeze(dim=0)
15
+ lead_aVR = ((prediction[0] + prediction[1])*(-0.5)).unsqueeze(dim=0)
16
+ lead_aVL = (prediction[0] - prediction[1]* 0.5).unsqueeze(dim=0)
17
+ lead_aVF = (prediction[1] - prediction[0]* 0.5).unsqueeze(dim=0)
18
+ all = torch.cat((prediction, lead_III, lead_aVR, lead_aVL, lead_aVF), dim=0)
19
+ all_corrected = all[torch.tensor([0,1,8, 9, 10, 11, 2,3,4,5,6,7])]
20
+
21
+ ecg_plot.plot(all_corrected, sample_rate = 500, title = 'ECG 12')
22
+
23
+ #ecg_plot.show()
24
+ buf = io.BytesIO()
25
+ plt.savefig(buf, format="png")
26
+ img = Image.open(buf)
27
+ return img
28
 
29
  gr.Interface(
30
  predict,
31
+ inputs=None,
32
+ outputs="image",
33
  title="Generating ECGs",
34
+ ).launch(share=True)
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  transformers
2
- torch
 
 
 
 
1
  transformers
2
+ torch
3
+ ecg-plot
4
+ matplotlib
5
+ PIL