dalexanderch commited on
Commit
6506504
1 Parent(s): 294e24f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -9,10 +9,9 @@ from glycowork.ml.processing import dataset_to_dataloader
9
  import numpy as np
10
  import torch
11
  import torch.nn as nn
12
- from networkx.drawing.nx_agraph import write_dot
13
- # import pygraphviz as pgv
14
- # from glycowork.motif.graph import glycan_to_nxGraph
15
- # import networkx as nx
16
 
17
  class EnsembleModel(nn.Module):
18
  def __init__(self, models):
@@ -41,15 +40,12 @@ model3 = torch.load("model3.pt", map_location=torch.device('cpu'))
41
 
42
  def fn(glycan, model):
43
  # Draw graph
44
- # graph = glycan_to_nxGraph(glycan)
45
- # node_labels = nx.get_node_attributes(graph, 'string_labels')
46
- # labels = {i:node_labels[i] for i in range(len(graph.nodes))}
47
- # graph = nx.relabel_nodes(graph, labels)
48
- # write_dot(graph, "graph.dot")
49
- # graph=pgv.AGraph("graph.dot")
50
- # graph.layout(prog='dot')
51
- # graph.draw("graph.png")
52
-
53
  # Perform inference
54
  if model == "No data augmentation":
55
  model_pred = model1
@@ -80,13 +76,13 @@ def fn(glycan, model):
80
  pred = np.exp(pred)/sum(np.exp(pred)) # Softmax
81
  pred = [float(x) for x in pred]
82
  pred = {class_list[i]:pred[i] for i in range(15)}
83
- return pred
84
 
85
 
86
  demo = gr.Interface(
87
  fn=fn,
88
  inputs=[gr.Textbox(label="Glycan sequence"), gr.Radio(label="Model",choices=["No data augmentation", "Random node deletion", "Ensemble"])],
89
- outputs=[gr.Label(num_top_classes=15, label="Prediction")],
90
  allow_flagging=False,
91
  title="SweetNet demo",
92
  examples=[["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
 
9
  import numpy as np
10
  import torch
11
  import torch.nn as nn
12
+ from glycowork.motif.graph import glycan_to_nxGraph
13
+ import networkx as nx
14
+ import matplotlib.pyplot as plt
 
15
 
16
  class EnsembleModel(nn.Module):
17
  def __init__(self, models):
 
40
 
41
  def fn(glycan, model):
42
  # Draw graph
43
+ graph = glycan_to_nxGraph(glycan)
44
+ node_labels = nx.get_node_attributes(graph, 'string_labels')
45
+ labels = {i:node_labels[i] for i in range(len(graph.nodes))}
46
+ graph = nx.relabel_nodes(graph, labels)
47
+ nx.draw(graph, with_labels=True)
48
+ plt.savefig("graph.png")
 
 
 
49
  # Perform inference
50
  if model == "No data augmentation":
51
  model_pred = model1
 
76
  pred = np.exp(pred)/sum(np.exp(pred)) # Softmax
77
  pred = [float(x) for x in pred]
78
  pred = {class_list[i]:pred[i] for i in range(15)}
79
+ return "graph.png", pred
80
 
81
 
82
  demo = gr.Interface(
83
  fn=fn,
84
  inputs=[gr.Textbox(label="Glycan sequence"), gr.Radio(label="Model",choices=["No data augmentation", "Random node deletion", "Ensemble"])],
85
+ outputs=[gr.Image(label="Glycan graph"), gr.Label(num_top_classes=15, label="Prediction")],
86
  allow_flagging=False,
87
  title="SweetNet demo",
88
  examples=[["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],