Spaces:
Runtime error
Runtime error
File size: 1,458 Bytes
e39cbff 7bf7082 e39cbff fe55e9c 4754bea fe55e9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from .model import EsmModel
from .utils import load_hub_workaround
import torch
def predict(peptide_list, model_path, device='cpu'):
    """Run the neuropeptide classifier and report the first input's probabilities.

    A fresh ``EsmModel`` is built, its weights are loaded from ``model_path``,
    and one forward pass over ``peptide_list`` is run without gradient tracking.

    Args:
        peptide_list: Passed unchanged to ``EsmModel.__call__``. Presumably a
            list of ``(name, sequence)`` pairs -- TODO confirm against EsmModel.
        model_path: Path to a checkpoint containing the model's ``state_dict``.
        device: Torch device the model is moved to and that is forwarded to
            the model call.

    Returns:
        dict: ``{'Neuropeptide': p1, 'Non-neuropeptide': p0}``, the softmax
        probabilities of the FIRST entry of ``peptide_list`` only. Any further
        entries are run through the model but not reported.
    """
    with torch.no_grad():
        model = EsmModel()
        model.eval()
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load trusted checkpoints (consider weights_only=True on
        # torch >= 1.13).
        # map_location="cpu" lets checkpoints saved on GPU load on any host.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model = model.to(device)
        # The second output (attention) is not needed; skip copying it to host.
        scores, _ = model(peptide_list, device)
        probs = torch.softmax(scores, dim=-1).cpu().tolist()
    # Column 1 is reported as the neuropeptide class, column 0 as the other.
    first = probs[0]
    return {'Neuropeptide': first[1], "Non-neuropeptide": first[0]}
def batch_predict(peptide_list, cutoff, model_path, device='cpu'):
    """Classify each peptide on its own and return one result row per item.

    A fresh ``EsmModel`` is built and its weights are loaded from
    ``model_path``. Each item of ``peptide_list`` is then passed to the model
    as a single-element list, without gradient tracking.

    Args:
        peptide_list: Iterable of items indexable at ``[0]`` and ``[1]``.
            Presumably ``(name, sequence)`` pairs -- TODO confirm against
            EsmModel.
        cutoff: Threshold on the neuropeptide probability. An item is labelled
            ``'Neuropeptide'`` only when its probability is strictly greater.
        model_path: Path to a checkpoint containing the model's ``state_dict``.
        device: Torch device the model is moved to and that is forwarded to
            the model call.

    Returns:
        list: One ``[item[0], item[1], prob_str, label]`` row per input item.
        ``prob_str`` is the neuropeptide probability formatted to three
        decimals. ``label`` is ``'Neuropeptide'`` or ``'Non-neuropeptide'``,
        decided on the unrounded probability.
    """
    with torch.no_grad():
        model = EsmModel()
        model.eval()
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load trusted checkpoints (consider weights_only=True on
        # torch >= 1.13).
        # map_location="cpu" lets checkpoints saved on GPU load on any host.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model = model.to(device)

        out = []
        for item in peptide_list:
            # Score items one at a time (batch of one), as the original did.
            # The attention output is unused, so it is not copied to host.
            scores, _ = model([item], device)
            probs = torch.softmax(scores, dim=-1).cpu().tolist()[0]
            p_neuro = probs[1]
            label = 'Neuropeptide' if p_neuro > cutoff else 'Non-neuropeptide'
            out.append([item[0], item[1], f"{p_neuro:.3f}", label])
    return out