dalexanderch commited on
Commit
8b25912
·
1 Parent(s): 6fa8fad

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -45
app.py CHANGED
@@ -8,52 +8,12 @@ import numpy as np
8
  import torch
9
  from glycowork.glycan_data.loader import lib
10
 
11
- # Update lib
12
- equivalence_classes = [
13
- ["Glc", "Man", "Gal", "Gul", "Alt", "All", "Tal", "Ido" ],
14
- ["GlcNAc", "ManNAc", "GalNAc", "GulNAc", "AltNAc", "AllNAc", "TalNAc", "IdoNAc"],
15
- ["GlcN", "ManN", "GalN", "GulN", "AltN", "AllN", "TalN", "IdoN"],
16
- ["GlcA", "ManA", "GalA", "GulA", "AltA", "AllA", "TalA", "IdoA"],
17
- ["Qui", "Rha", "6dGul", "6dAlt", "6dTal", "Fuc"],
18
- ["QuiNAc", "RhaNAc", "6dAltNAc", "6dTalNAc", "FucNAc"],
19
- ["Oli", "Tyv", "Abe", "Par", "Dig", "Col"],
20
- ["Ara", "Lyx", "Xyl", "Rib"],
21
- ["Kdn", "Neu5Ac", "Neu5Gc", "Neu", "Sia"],
22
- ["Pse", "Leg", "Aci", "4eLeg"],
23
- ["Bac", "LDmanHep", "Kdo", "Dha", "DDmanHep", "MurNAc", "MurNGc", "Mur", "Api", "Fru", "Tag", "Sor", "Psi"]
24
- ]
25
-
26
- linkage_classes = [
27
- ["a1-2", "a1-z", "z1-2", "z1-z"],
28
- ["a1-3", "a1-z", "z1-3", "z1-z"],
29
- ["a1-4", "a1-z", "z1-4", "z1-z"],
30
- ["a1-6", "a1-z", "z1-6", "z1-z"],
31
- ["b1-2", "b1-z", "z1-2", "z1-z"],
32
- ["b1-3", "b1-z", "z1-3", "z1-z"],
33
- ["b1-4", "b1-z", "z1-4", "z1-z"],
34
- ["b1-6", "b1-z", "z1-6", "z1-z"],
35
- ["a2-3", "a2-z", "z2-3", "z2-z"],
36
- ["a2-6", "a2-z", "z2-6", "z2-z"],
37
- ["a2-8", "a2-z", "z2-8", "z2-z"]
38
- ]
39
-
40
- # Update lib
41
- print(len(lib))
42
- for equivalence_class in equivalence_classes:
43
- for target in equivalence_class:
44
- if target not in lib:
45
- lib.append(target)
46
- for linkage_class in linkage_classes:
47
- for target in linkage_class:
48
- if target not in lib:
49
- lib.append(target)
50
- print(len(lib))
51
 
52
  def fn(model, class_list):
53
  def f(glycan):
54
  glycan = [glycan]
55
  label = [0]
56
- data = next(iter(dataset_to_dataloader(glycan, label, batch_size=1, libr=lib)))
57
  device = "cpu"
58
  if torch.cuda.is_available():
59
  device = "cuda:0"
@@ -63,9 +23,10 @@ def fn(model, class_list):
63
  x = x.to(device)
64
  edge_index = edge_index.to(device)
65
  batch = batch.to(device)
66
- pred = model(x,edge_index, batch).cpu().detach().numpy()
67
- pred = np.argmax(pred)
68
- pred = class_list[pred]
 
69
  return pred
70
  return f
71
 
@@ -79,7 +40,7 @@ f = fn(model, class_list)
79
  demo = gr.Interface(
80
  fn=f,
81
  inputs=[gr.Textbox(label="Glycan sequence")],
82
- outputs=[gr.Textbox(label="Predicted Class")],
83
  allow_flagging=False,
84
  title="SweetNet demo",
85
  examples=["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
 
8
  import torch
9
  from glycowork.glycan_data.loader import lib
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def fn(model, class_list):
13
  def f(glycan):
14
  glycan = [glycan]
15
  label = [0]
16
+ data = next(iter(dataset_to_dataloader(glycan, label, batch_size=1)))
17
  device = "cpu"
18
  if torch.cuda.is_available():
19
  device = "cuda:0"
 
23
  x = x.to(device)
24
  edge_index = edge_index.to(device)
25
  batch = batch.to(device)
26
+ pred = model(x,edge_index, batch).cpu().detach().numpy()[0]
27
+ pred = np.exp(pred)/sum(np.exp(pred)) # Softmax
28
+ pred = [float(x) for x in pred]
29
+ pred = {class_list[i]:pred[i] for i in range(15)}
30
  return pred
31
  return f
32
 
 
40
  demo = gr.Interface(
41
  fn=f,
42
  inputs=[gr.Textbox(label="Glycan sequence")],
43
+ outputs=[gr.Label(num_top_classes=15, label="Class prediction")],
44
  allow_flagging=False,
45
  title="SweetNet demo",
46
  examples=["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",