"""
main model 
"""
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange
import os

from .utils import length_to_mask, load_model_and_alphabet_core


class EsmModel(nn.Module):
    """ESM encoder followed by per-head convolutions, attention pooling, and a classification head."""

    def __init__(self, hidden_size=64, num_labels=2, projection_size=24, head=12):
        super().__init__()

        # Load the pretrained ESM encoder and its alphabet from the checkpoint
        # file bundled next to this module.
        basedir = os.path.abspath(os.path.dirname(__file__))
        self.esm, self.alphabet = load_model_and_alphabet_core(os.path.join(basedir, 'args.pt'))

        self.num_labels = num_labels
        self.head = head
        self.hidden_size = hidden_size

        # Per-head projection: hidden_size is the ESM embedding dim divided by the number of heads.
        self.projection = nn.Linear(hidden_size, projection_size)
        # Two 1-D convolutions over the sequence dimension; the second halves the channel count.
        self.cov_1 = nn.Conv1d(projection_size, projection_size, kernel_size=3, padding='same')
        self.cov_2 = nn.Conv1d(projection_size, projection_size // 2, kernel_size=1, padding='same')
        # self.gating = nn.Linear(projection_size, projection_size)

        # One learnable attention query vector per head.
        self.W = nn.Parameter(torch.randn((head, projection_size // 2)))
        # self.mu = nn.Parameter(torch.randn((1, 768)))

        # Classification head over the concatenated per-head pooled features.
        self.fcn = nn.Sequential(
            nn.Linear((projection_size // 2) * head, projection_size // 2),
            nn.ReLU(),
            nn.Linear(projection_size // 2, num_labels),
        )

    def forward(self, peptide_list, device='cpu'):
        # peptide_list: list of (name, sequence) pairs, as expected by the ESM batch converter.
        peptide_length = [len(i[1]) for i in peptide_list]
        batch_converter = self.alphabet.get_batch_converter()
        _, _, batch_tokens = batch_converter(peptide_list)
        batch_tokens = batch_tokens.to(device)

        # Layer-12 ESM representations; drop the BOS token at position 0.
        protein_dict = self.esm(batch_tokens, repr_layers=[12], return_contacts=False)
        protein_embeddings = protein_dict["representations"][12][:, 1:, :]

        # Split the embedding dimension into `head` heads and fold them into the batch axis.
        protein_embed = rearrange(protein_embeddings, 'b l (h d) -> (b h) l d', h=self.head)
        representations = self.projection(protein_embed)

        # Conv1d expects (batch, channels, length).
        representations = rearrange(representations, 'b l d -> b d l')
        representation_cov = F.relu(self.cov_1(representations))
        representation_cov = F.relu(self.cov_2(representation_cov))
        representations = rearrange(representation_cov, '(b h) d l -> b h l d', h=self.head)

        # Per-head attention scores against the learned query vectors;
        # padded positions are masked out before the softmax.
        att = torch.einsum('bhld,hd->bhl', representations, self.W)
        mask = length_to_mask(torch.tensor(peptide_length)).to(device).int()
        att = att.masked_fill(mask.unsqueeze(1) == 0, -np.inf)
        att = F.softmax(att, dim=-1)

        # Weight positions by attention, concatenate heads, pool by summing over the sequence,
        # then classify.
        representations = rearrange(representations * att.unsqueeze(-1), 'b h l d -> b l (h d)')
        representations = torch.sum(representations, dim=1)
        return self.fcn(representations), att
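

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the original training code).
    # It assumes the bundled 'args.pt' checkpoint sits next to this module and that
    # the package is executed as a module (e.g. `python -m <package>.model`), since
    # the relative import above prevents running this file directly.
    # The peptide names and sequences below are hypothetical placeholders.
    model = EsmModel(hidden_size=64, num_labels=2, projection_size=24, head=12)
    model.eval()

    peptides = [
        ("pep1", "GIGKFLHSAKKFGKAFVGEIMNS"),                # placeholder sequence
        ("pep2", "KWKLFKKIEKVGQNIRDGIIKAGPAVAVVGQATQIAK"),  # placeholder sequence
    ]

    with torch.no_grad():
        logits, attention = model(peptides, device="cpu")

    probs = F.softmax(logits, dim=-1)   # (batch, num_labels) class probabilities
    print(probs)
    print(attention.shape)              # (batch, head, seq_len) per-head attention weights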