dalexanderch commited on
Commit
840fdaa
1 Parent(s): 47aa6b1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -1,11 +1,18 @@
1
  import os
2
  os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
3
  os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
 
 
 
4
  import gradio as gr
5
  from glycowork.ml.processing import dataset_to_dataloader
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
 
 
 
 
9
 
10
  class EnsembleModel(nn.Module):
11
  def __init__(self, models):
@@ -33,6 +40,17 @@ model2 = torch.load("model2.pt", map_location=torch.device('cpu'))
33
  model3 = torch.load("model3.pt", map_location=torch.device('cpu'))
34
 
35
  def fn(glycan, model):
 
 
 
 
 
 
 
 
 
 
 
36
  if model == "No data augmentation":
37
  model_pred = model1
38
  model_pred.eval()
@@ -62,13 +80,13 @@ def fn(glycan, model):
62
  pred = np.exp(pred)/sum(np.exp(pred)) # Softmax
63
  pred = [float(x) for x in pred]
64
  pred = {class_list[i]:pred[i] for i in range(15)}
65
- return pred
66
 
67
 
68
  demo = gr.Interface(
69
  fn=fn,
70
  inputs=[gr.Textbox(label="Glycan sequence"), gr.Radio(label="Model",choices=["No data augmentation", "Random node deletion", "Ensemble"])],
71
- outputs=[gr.Label(num_top_classes=15, label="Prediction")],
72
  allow_flagging=False,
73
  title="SweetNet demo",
74
  examples=[["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
 
1
  import os
2
  os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
3
  os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
4
+ os.system("apt-get install -y graphviz-dev")
5
+ os.system("pip install pygraphviz")
6
+
7
  import gradio as gr
8
  from glycowork.ml.processing import dataset_to_dataloader
9
  import numpy as np
10
  import torch
11
  import torch.nn as nn
12
+ from networkx.drawing.nx_agraph import write_dot
13
+ import pygraphviz as pgv
14
+ from glycowork.motif.graph import glycan_to_nxGraph
15
+ import networkx as nx
16
 
17
  class EnsembleModel(nn.Module):
18
  def __init__(self, models):
 
40
  model3 = torch.load("model3.pt", map_location=torch.device('cpu'))
41
 
42
  def fn(glycan, model):
43
+ # Draw graph
44
+ graph = glycan_to_nxGraph(glycan)
45
+ node_labels = nx.get_node_attributes(graph, 'string_labels')
46
+ labels = {i:node_labels[i] for i in range(len(graph.nodes))}
47
+ graph = nx.relabel_nodes(graph, labels)
48
+ write_dot(graph, "graph.dot")
49
+ graph=pgv.AGraph("graph.dot")
50
+ graph.layout(prog='dot')
51
+ graph.draw("graph.png")
52
+
53
+ # Perform inference
54
  if model == "No data augmentation":
55
  model_pred = model1
56
  model_pred.eval()
 
80
  pred = np.exp(pred)/sum(np.exp(pred)) # Softmax
81
  pred = [float(x) for x in pred]
82
  pred = {class_list[i]:pred[i] for i in range(15)}
83
+ return pred, "graph.png"
84
 
85
 
86
  demo = gr.Interface(
87
  fn=fn,
88
  inputs=[gr.Textbox(label="Glycan sequence"), gr.Radio(label="Model",choices=["No data augmentation", "Random node deletion", "Ensemble"])],
89
+ outputs=[gr.Label(num_top_classes=15, label="Prediction"), gr.Image(label="Graph visualization")],
90
  allow_flagging=False,
91
  title="SweetNet demo",
92
  examples=[["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],