dalexanderch commited on
Commit
50edbe9
1 Parent(s): 85f8980

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -6,8 +6,26 @@ import gradio as gr
6
  from glycowork.ml.processing import dataset_to_dataloader
7
  import numpy as np
8
  import torch
 
9
 
10
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class_list=['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
12
  'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria']
13
 
 
6
  from glycowork.ml.processing import dataset_to_dataloader
7
  import numpy as np
8
  import torch
9
+ import torch.nn as nn
10
 
11
+ class EnsembleModel(nn.Module):
12
+ def __init__(self, models):
13
+ super().__init__()
14
+ self.models = models
15
+
16
+ def forward(self, data):
17
+ # Check if GPU available
18
+ device = "cpu"
19
+ if torch.cuda.is_available():
20
+ device = "cuda:0"
21
+ # Prepare data
22
+ x = data.labels.to(device)
23
+ edge_index = data.edge_index.to(device)
24
+ batch = data.batch.to(device)
25
+ y_pred = [model(x,edge_index, batch).cpu().detach().numpy() for model in self.models]
26
+ y_pred = np.mean(y_pred,axis=0)[0]
27
+ return y_pred
28
+
29
  class_list=['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
30
  'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria']
31