import torch

from .model import EsmModel
from .utils import load_hub_workaround  # NOTE(review): currently unused in this module


def _load_model(model_path, device):
    """Build an ``EsmModel`` in eval mode, load its weights, and move it to *device*.

    Args:
        model_path: Path to a checkpoint holding an ``EsmModel`` ``state_dict``.
        device: Torch device (e.g. ``'cpu'`` or ``'cuda'``) the model is moved to.

    Returns:
        The loaded ``EsmModel``, already on *device* and in eval mode.
    """
    model = EsmModel()
    model.eval()
    # Weights are mapped to CPU first so a checkpoint saved on a GPU can still be
    # loaded on a CPU-only machine; the model is moved to *device* afterwards.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(device)


def predict(peptide_list, model_path, device='cpu'):
    """Run the classifier on *peptide_list* and return probabilities for its first entry.

    Args:
        peptide_list: Batch passed unchanged to ``EsmModel``'s forward call.
            Presumably a list of ``(name, sequence)`` pairs -- TODO confirm
            against ``EsmModel``.
        model_path: Path to the saved ``EsmModel`` ``state_dict``.
        device: Torch device used for inference.

    Returns:
        dict with keys ``'Neuropeptide'`` and ``"Non-neuropeptide"``, holding the
        softmax probabilities of class index 1 and class index 0 for the first
        row of the model output. Any further rows are ignored.
    """
    with torch.no_grad():
        model = _load_model(model_path, device)
        # The model returns (logits, attention); attention is not needed here.
        logits, _attention = model(peptide_list, device)
        pred = torch.softmax(logits, dim=-1).cpu().tolist()
    return {'Neuropeptide': pred[0][1], "Non-neuropeptide": pred[0][0]}


def batch_predict(peptide_list, cutoff, model_path, device='cpu'):
    """Classify each entry of *peptide_list* and label it against *cutoff*.

    The model is loaded once. Each entry is then scored in its own forward pass.

    Args:
        peptide_list: Iterable of entries indexable as ``item[0]`` and ``item[1]``.
            Presumably ``(name, sequence)`` pairs -- TODO confirm.
        cutoff: Threshold on the class-1 ("Neuropeptide") probability. An entry is
            labelled ``'Neuropeptide'`` only when its probability is strictly greater.
        model_path: Path to the saved ``EsmModel`` ``state_dict``.
        device: Torch device used for inference.

    Returns:
        list of ``[item[0], item[1], probability formatted to 3 decimals, label]``,
        one row per input entry, in input order.
    """
    with torch.no_grad():
        model = _load_model(model_path, device)
        out = []
        for item in peptide_list:
            # The model returns (logits, attention); attention is not needed here.
            logits, _attention = model([item], device)
            # A single-item batch yields one row of class probabilities.
            row = torch.softmax(logits, dim=-1).cpu().tolist()[0]
            neuro_prob = row[1]
            label = 'Neuropeptide' if neuro_prob > cutoff else 'Non-neuropeptide'
            out.append([item[0], item[1], f"{neuro_prob:.3f}", label])
    return out