wnagleiofficial
fix app batch
4754bea
from .model import EsmModel
from .utils import load_hub_workaround
import torch
def predict(peptide_list, model_path, device='cpu'):
    """Run the NeuroPred ESM model and return class probabilities for one peptide.

    A fresh ``EsmModel`` is built and loaded from ``model_path`` on every call.
    One forward pass is run over ``peptide_list``, and a softmax is applied over
    the last dimension of the model's first output.

    Args:
        peptide_list: Non-empty list of peptide records, passed unchanged to
            ``EsmModel``. Only the FIRST record's prediction is returned.
            (Presumably ``(name, sequence)`` tuples, as ``batch_predict``
            indexes them -- TODO confirm against ``EsmModel``.)
        model_path: Path to a state dict readable by ``torch.load``.
        device: Torch device to run inference on (default ``'cpu'``).

    Returns:
        dict mapping ``'Neuropeptide'`` to softmax column 1 and
        ``"Non-neuropeptide"`` to softmax column 0 of the first row.

    Raises:
        ValueError: if ``peptide_list`` is empty.
    """
    if not peptide_list:
        # Fail early with a clear message instead of an opaque error from the
        # model or an IndexError on pred[0] below.
        raise ValueError("peptide_list must contain at least one peptide")
    with torch.no_grad():
        neuroPred_model = EsmModel()
        neuroPred_model.eval()
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True if the installed torch supports it).
        state_dict = torch.load(model_path, map_location="cpu")
        neuroPred_model.load_state_dict(state_dict)
        neuroPred_model = neuroPred_model.to(device)
        # The second output (attention) is unused here, so it is discarded
        # rather than copied back to host memory.
        scores, _ = neuroPred_model(peptide_list, device)
        pred = torch.softmax(scores, dim=-1).cpu().tolist()
    return {'Neuropeptide': pred[0][1], "Non-neuropeptide": pred[0][0]}
def batch_predict(peptide_list, cutoff, model_path, device='cpu'):
    """Classify each peptide with the NeuroPred ESM model, one at a time.

    The model is loaded once. It is then called separately for every record,
    each time with a single-element list, so every result comes from its own
    forward pass.

    Args:
        peptide_list: Iterable of peptide records. Each record is indexed as
            ``record[0]`` and ``record[1]`` (presumably ``(name, sequence)``
            -- TODO confirm against ``EsmModel``).
        cutoff: Threshold on the class-1 softmax probability. A record is
            labelled ``'Neuropeptide'`` only when its unrounded probability is
            strictly greater than ``cutoff``.
        model_path: Path to a state dict readable by ``torch.load``.
        device: Torch device to run inference on (default ``'cpu'``).

    Returns:
        list of ``[record[0], record[1], prob_str, label]`` rows, in input order.
        ``prob_str`` is the class-1 probability formatted to three decimals.
        ``label`` is ``'Neuropeptide'`` or ``'Non-neuropeptide'``.
        An empty ``peptide_list`` yields ``[]``.
    """
    with torch.no_grad():
        neuroPred_model = EsmModel()
        neuroPred_model.eval()
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True if the installed torch supports it).
        state_dict = torch.load(model_path, map_location="cpu")
        neuroPred_model.load_state_dict(state_dict)
        neuroPred_model = neuroPred_model.to(device)
        out = []
        for item in peptide_list:
            # Attention output is unused; skip the device->host copy.
            scores, _ = neuroPred_model([item], device)
            # Single-record batch: row 0 holds this record's class probabilities.
            probs = torch.softmax(scores, dim=-1).cpu().tolist()[0]
            label = 'Neuropeptide' if probs[1] > cutoff else 'Non-neuropeptide'
            out.append([item[0], item[1], f"{probs[1]:.3f}", label])
    return out