SweetNet / app.py
dalexanderch's picture
Upload app.py
44c8341
raw
history blame
No virus
1.4 kB
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
def _model_device(model):
    """Return the device that holds ``model``'s first parameter.

    For a model with no parameters, fall back to ``cuda:0`` when CUDA is
    available and to the CPU otherwise.
    """
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def fn(model, class_list):
    """Build a Gradio prediction callback for a glycan classifier.

    Parameters
    ----------
    model : torch.nn.Module
        Trained network, called as ``model(x, edge_index, batch)``. The index
        of the largest output value selects the class. The caller is
        expected to have put it in eval mode already.
    class_list : sequence of str
        Class names, indexed by the model's output position.

    Returns
    -------
    callable
        ``f(glycan: str) -> str``, which maps a single glycan sequence to the
        name of its highest-scoring class.
    """

    def f(glycan):
        """Return the predicted class name for one glycan sequence.

        Raises
        ------
        ValueError
            If ``glycan`` is empty or only whitespace.
        """
        # Text pasted into the textbox often carries stray spaces/newlines.
        glycan = glycan.strip()
        if not glycan:
            raise ValueError("Please enter a glycan sequence.")
        # dataset_to_dataloader takes parallel lists of sequences and labels.
        # The label is a dummy placeholder; only the input graph is used.
        data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
        # Put the inputs wherever the model's weights are. Moving only the
        # inputs to "cuda:0" breaks when the model itself sits on the CPU.
        device = _model_device(model)
        x = data.labels.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)
        # Pure inference: skip the autograd bookkeeping.
        with torch.no_grad():
            pred = model(x, edge_index, batch).cpu().numpy()
        # np.argmax flattens its input, which is fine for a batch of one.
        return class_list[int(np.argmax(pred))]

    return f
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model.pt holds a whole pickled nn.Module (not a state_dict), so it must be
# loaded with weights_only=False. PyTorch 2.6 changed the default to True;
# the keyword itself needs PyTorch >= 1.13. Unpickling can run arbitrary
# code, so only ship checkpoints from a trusted source. map_location lets a
# checkpoint saved on a GPU load on a CPU-only host.
model = torch.load("model.pt", map_location=device, weights_only=False)
model.to(device)
model.eval()

# Output index -> class name. The order must match the order used at
# training time.
class_list = ['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista',
              'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
              'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota',
              'Protista', 'Riboviria']

f = fn(model, class_list)

demo = gr.Interface(
    fn=f,
    inputs=[gr.Textbox(label="Glycan sequence")],
    outputs=[gr.Textbox(label="Predicted Class")],
    # The documented string value; the boolean False form is deprecated.
    allow_flagging="never",
    title="SweetNet demo",
    examples=["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
              "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
              "Neu5Ac(a2-3)Gal(b1-3)[Neu5Ac(a2-6)]GlcNAc(b1-3)Gal(b1-4)Glc-ol"],
)
demo.launch(debug=True)