File size: 1,458 Bytes
e39cbff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bf7082
e39cbff
fe55e9c
 
 
 
 
 
 
 
 
4754bea
 
 
 
 
 
 
fe55e9c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from .model import EsmModel
from .utils import load_hub_workaround
import torch

def predict(peptide_list, model_path, device='cpu'):
    """Run the classifier once and return the class probabilities of the first entry.

    Builds an ``EsmModel``, loads the state dict stored at *model_path*, moves the
    model to *device* and runs it on *peptide_list*. Only the first output row is
    reported, even if *peptide_list* holds several entries.

    Args:
        peptide_list: Passed unchanged to the model call.
            NOTE(review): ``batch_predict`` indexes items as ``item[0]`` and
            ``item[1]``, which suggests ``(label, sequence)`` pairs — confirm.
        model_path: Path of a file readable by ``torch.load`` that holds the
            model's state dict.
        device: Device the model is moved to; it is also passed to the model call.

    Returns:
        dict: ``{'Neuropeptide': p1, 'Non-neuropeptide': p0}``. Here ``p1`` and ``p0``
        are the softmax probabilities in columns 1 and 0 of the first output row.
    """
    with torch.no_grad():
        neuroPred_model = EsmModel()
        neuroPred_model.eval()
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True once the minimum supported
        # torch version accepts it.
        state_dict = torch.load(model_path, map_location="cpu")
        neuroPred_model.load_state_dict(state_dict)
        neuroPred_model = neuroPred_model.to(device)
        # The second return value (attention) is not part of this function's
        # output, so it is not copied back to the host.
        scores, _ = neuroPred_model(peptide_list, device)
        probs = torch.softmax(scores, dim=-1).cpu().tolist()
    first = probs[0]
    return {'Neuropeptide': first[1], "Non-neuropeptide": first[0]}

def batch_predict(peptide_list, cutoff, model_path, device='cpu'):
    """Classify each peptide separately and return one result row per peptide.

    The model is built and loaded once. Each item of *peptide_list* is then run
    through it on its own, as a one-element list.

    Args:
        peptide_list: Iterable of items that are indexable at positions 0 and 1.
            NOTE(review): these are presumably ``(label, sequence)`` pairs —
            confirm against the model's input format.
        cutoff: Threshold compared against the column-1 ("Neuropeptide") probability.
        model_path: Path of a file readable by ``torch.load`` that holds the
            model's state dict.
        device: Device the model is moved to; it is also passed to the model call.

    Returns:
        list: One ``[item[0], item[1], prob_str, label]`` row per input item, where:

        - ``prob_str`` is the column-1 probability formatted to three decimals.
        - ``label`` is ``'Neuropeptide'`` when that probability is strictly greater
          than *cutoff*, and ``'Non-neuropeptide'`` otherwise.

        The comparison uses the unrounded probability.
    """
    with torch.no_grad():
        neuroPred_model = EsmModel()
        neuroPred_model.eval()
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True once the minimum supported
        # torch version accepts it.
        state_dict = torch.load(model_path, map_location="cpu")
        neuroPred_model.load_state_dict(state_dict)
        neuroPred_model = neuroPred_model.to(device)
        out = []
        for item in peptide_list:
            # The attention output is unused, so skip the per-item host copy.
            scores, _ = neuroPred_model([item], device)
            # The batch holds one item, so only the first output row is relevant.
            probs = torch.softmax(scores, dim=-1).cpu().tolist()[0]
            neuro_prob = probs[1]
            label = 'Neuropeptide' if neuro_prob > cutoff else 'Non-neuropeptide'
            out.append([item[0], item[1], f"{neuro_prob:.3f}", label])
    return out