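# Hugging Face Spaces page header captured alongside this file: "Spaces: Runtime error"
# (page status text only; not part of the program).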
import os
import subprocess
import sys


def _pip_install(*args):
    """Run ``pip install *args`` with the pip of the running interpreter.

    Unlike ``os.system("pip install ...")`` this installs into the
    environment of ``sys.executable`` (the one that will import the
    packages). It also raises ``subprocess.CalledProcessError`` when pip
    fails, instead of silently continuing to the imports below.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


# CPU-only PyTorch build.
_pip_install(
    "torch", "torchvision", "torchaudio",
    "--extra-index-url", "https://download.pytorch.org/whl/cpu",
)

import torch  # noqa: E402  (must come after the install above)

# PyG's compiled extensions must match the installed torch build. A hard-coded
# torch-1.12.0 wheel index does not match an unpinned (latest) torch install.
# PyG publishes one wheel index per torch minor release, named "<major>.<minor>.0".
_torch_major_minor = ".".join(torch.__version__.split("+")[0].split(".")[:2])
_pip_install(
    "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
    "torch-geometric",
    "-f", f"https://data.pyg.org/whl/torch-{_torch_major_minor}.0+cpu.html",
)

import gradio as gr  # noqa: E402
from glycowork.ml.processing import dataset_to_dataloader  # noqa: E402
import numpy as np  # noqa: E402
import torch.nn as nn  # noqa: E402
from glycowork.motif.graph import glycan_to_nxGraph  # noqa: E402
import networkx as nx  # noqa: E402
import pydot  # noqa: E402
# import pygraphviz as pgv
class EnsembleModel(nn.Module):
    """Average the predictions of several graph models.

    Each wrapped model is called as ``model(x, edge_index, batch)``. Its
    output is converted to NumPy, the outputs are averaged element-wise
    across models, and row 0 (the first graph in the batch) is returned.

    ``self.models`` is a plain list, not an ``nn.ModuleList``, so the base
    ``nn.Module.train``/``eval`` would not reach the wrapped models.
    ``train`` is overridden to forward the mode explicitly. Because the
    override lives on the class, it also applies to instances restored with
    ``torch.load``, which bypasses ``__init__``.
    """

    def __init__(self, models):
        super().__init__()
        self.models = models

    def train(self, mode=True):
        """Set training mode on this module and on every wrapped model."""
        super().train(mode)
        for sub_model in self.models:
            sub_model.train(mode)
        return self

    @staticmethod
    def _device_of(model):
        """Return the device of *model*'s first parameter (CPU if it has none)."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, data):
        """Return the mean prediction of all wrapped models for the first graph.

        data: object exposing ``labels``, ``edge_index`` and ``batch``
            tensors (presumably a PyG batch from ``dataset_to_dataloader``;
            confirm against callers).
        Returns a NumPy array: the element-wise mean of the models' outputs,
        taken at row 0.
        """
        y_pred = []
        with torch.no_grad():
            for sub_model in self.models:
                # Send inputs to wherever this model's weights live. Choosing
                # "cuda:0" whenever CUDA exists breaks models loaded on CPU.
                device = self._device_of(sub_model)
                out = sub_model(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                y_pred.append(out.cpu().detach().numpy())
        return np.mean(y_pred, axis=0)[0]
# Class names reported to the user, one per model output position:
# output index i is shown as class_list[i].
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]
def _load_model(path):
    """Load a fully pickled model from *path* onto the CPU.

    The loaded objects are used directly as modules (``.eval()`` and calls),
    so these files hold whole pickled objects rather than state dicts.
    PyTorch 2.6 changed ``torch.load``'s default to ``weights_only=True``,
    which refuses to unpickle arbitrary classes, so ``weights_only=False``
    is passed explicitly. The keyword is skipped on older torch versions
    that predate it (added in 1.13). Full unpickling can execute arbitrary
    code, so only use this on the trusted files bundled with the app.
    """
    import inspect

    kwargs = {"map_location": torch.device('cpu')}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


model1 = _load_model("model1.pt")
model2 = _load_model("model2.pt")
model3 = _load_model("model3.pt")
model4 = _load_model("model4.pt")
model5 = _load_model("model5.pt")
model6 = _load_model("model6.pt")
model7 = _load_model("model7.pt")
def _draw_glycan_png(glycan):
    """Render *glycan* with Graphviz ``dot`` into a new temporary PNG and return its path."""
    import tempfile

    graph = glycan_to_nxGraph(glycan)
    # Show each node's 'string_labels' text through the Graphviz "label"
    # attribute instead of renaming the nodes. nx.relabel_nodes merges nodes
    # that map to the same name, so repeated residues or linkages (e.g.
    # several "Man") would collapse into one node in the drawing.
    nx.set_node_attributes(graph, nx.get_node_attributes(graph, 'string_labels'), 'label')
    dot_graph = nx.drawing.nx_pydot.to_pydot(graph)
    dot_graph.set_prog("dot")
    # Use a unique file per request: with a single shared "graph.png", a
    # concurrent request can overwrite the image before it is served.
    # NOTE: the file is not deleted here because the caller (Gradio) reads
    # the path after fn returns.
    fd, png_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    dot_graph.write_png(png_path)
    return png_path


def _select_model(name):
    """Return ``(model, is_ensemble)`` for a UI model name, ignoring case.

    Unrecognised names, including None (no radio selection), fall back to
    model2, the model behind "Random node deletion".
    """
    table = {
        "no data augmentation": (model1, False),
        "random node deletion": (model2, False),
        "ensemble": (model3, True),
        "bootstrap ensemble": (model4, True),
        "random edge deletion": (model5, False),
        "hierarchy substitution": (model6, False),
        "adjusted class weights": (model7, False),
    }
    return table.get((name or "").casefold(), (model2, False))


def _softmax(logits):
    """Return the softmax of a 1-D array of logits, shifted by its max to avoid overflow."""
    values = np.asarray(logits, dtype=np.float64)
    exp = np.exp(values - values.max())
    return exp / exp.sum()


def fn(glycan, model):
    """Draw *glycan* as a graph and classify it with the selected model.

    glycan: glycan sequence string, as typed in the UI textbox.
    model: one of the UI model names. Matching ignores case, so
        "Bootstrap Ensemble" selects the bootstrap ensemble. Unknown names
        or None fall back to model2.

    Returns ``(pred, png_path)``:
        pred: dict mapping each name in ``class_list`` to its softmax
            probability.
        png_path: path of a freshly written PNG of the glycan graph.
    """
    png_path = _draw_glycan_png(glycan)

    model_pred, is_ensemble = _select_model(model)
    model_pred.eval()

    # dataset_to_dataloader takes parallel glycan/label lists. The label 0 is
    # a placeholder; presumably unused for inference (confirm).
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
    with torch.no_grad():
        if is_ensemble:
            # The ensemble's forward takes the whole batch and already returns
            # a NumPy array for the first graph.
            logits = model_pred(data)
        else:
            device = "cpu"  # the models are loaded on CPU in this file
            out = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
            logits = out.cpu().detach().numpy()[0]

    probs = _softmax(logits)
    pred = {name: float(p) for name, p in zip(class_list, probs)}
    return pred, png_path
# Model names offered in the UI. They must match the names fn() dispatches on.
# "Random node deletion" is the model fn() falls back to when no known choice
# is given (e.g. an empty selection), so it is preselected to make that visible.
model_choices = [
    "No data augmentation",
    "Random node deletion",
    "Random edge deletion",
    "Ensemble",
    "Bootstrap ensemble",
    "Hierarchy substitution",
    "Adjusted class weights",
]

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence", value="Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc"),
        gr.Radio(label="Model", choices=model_choices, value="Random node deletion"),
    ],
    outputs=[gr.Label(num_top_classes=15, label="Prediction"), gr.Image(label="Glycan graph")],
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        ["Hex(z1-z)[GlcA(z1-z)]Hex(z1-z)Hex(z1-z)Hex", "Random node deletion"],
        ["Man(a1-2)Man(a1-3)[Man(a1-2)Man(a1-6)[Man(a1-3)]Man(a1-6)]Man(b1-3)GlcNAc", "No data augmentation"],
        ["Man(a1-2)Man(a1-2)Man(a1-3)[Glc(z1-z)Gal(a1-2)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc", "Ensemble"],
        # Must match the Radio choice exactly; "Bootstrap Ensemble" (capital E)
        # is not one of the choices.
        ["L-GulA(a1-4)ManA(b1-4)Man(b1-4)L-GulA(a1-4)L-GulA(a1-4)ManA", "Bootstrap ensemble"],
        ["GlcNAc(b1-4)GlcNAc(b1-4)GlcNAc(b1-4)GlcNAc", "Random edge deletion"],
        ["Gal(b1-4)GlcNAc(b1-4)GlcNAc(b1-4)GlcNAc(b1-4)GlcNAc(b1-4)GlcNAc", "Adjusted class weights"],
    ],
)

demo.launch(debug=True)