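"""Protein graph construction.

Builds per-residue node features from ESM-2 embeddings, VHSE8
physicochemical descriptors, and sinusoidal positional encodings, plus
per-residue-pair edge features from ESM-2 predicted contact maps.
"""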
import math

import torch
import torch.nn as nn


class ProteinGraph(nn.Module):
    """Produces node and edge representations for a protein graph."""

    def __init__(self, d_node, d_edge, d_position):
        super(ProteinGraph, self).__init__()
        self.d_node = d_node
        self.d_edge = d_edge
        self.d_position = d_position
        # Raw node features: ESM-2 embedding (1280) + VHSE8 (8) + position (d_position).
        d_node_original = 1280 + 8 + d_position
        self.node_mapping = nn.Linear(d_node_original, self.d_node)
        # Maps each scalar contact probability to a d_edge-dimensional edge feature.
        self.linear_edge = nn.Linear(1, d_edge)
        # VHSE8 physicochemical descriptor values per amino acid.
        # 'X' (unknown) reuses the values of 'A'; 'B' (Asx) reuses those of 'D'.
        vhse8_values = {
            'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
            'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
            'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
            'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
            'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
            'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
            'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
            'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
            'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
            'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
            'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
            'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
            'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
            'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
            'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
            'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
            'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
            'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
            'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
            'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
            'X': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
            'B': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
        }
        # Token indices used by the ESM-2 alphabet for each amino acid.
        aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9,
                     'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18,
                     'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7, 'X': 24, 'B': 25}
        # Fixed lookup table indexed by token id; rows for special tokens
        # (<cls>, <pad>, <eos>, <unk>) stay zero.
        self.vhse8_tensor = torch.zeros(26, 8)
        for aa, values in vhse8_values.items():
            self.vhse8_tensor[aa_to_idx[aa]] = torch.tensor(values)
        self.vhse8_tensor.requires_grad = False
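        # Note (design): this table could also be registered as a buffer via
        # self.register_buffer('vhse8_tensor', ...), which would let .to(device)
        # move it automatically; here it is moved by hand in forward() and is
        # kept out of the state dict.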
        # self.position_embedding = nn.Embedding(seq_len, self.d_position)

    # def one_hot_encoding(self, seq_len):
    #     positions = torch.arange(seq_len).unsqueeze(1)
    #     one_hot = torch.nn.functional.one_hot(positions, num_classes=seq_len).squeeze(1)
    #     return one_hot

    def create_sinusoidal_embeddings(self, seq_len, d_position):
        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_position, 2) * -(math.log(10000.0) / d_position))
        pe = torch.zeros(seq_len, d_position)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, seq_len, d_position)
        return pe
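    # The encoding above follows the standard transformer convention:
    #   PE[pos, 2i]   = sin(pos / 10000^(2i / d_position))
    #   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_position))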

    def add_cls_eos(self, tensor):
        # Re-insert the special tokens that ESM-2 expects for contact
        # prediction: <cls> (0) at the start and <eos> (2) before the first
        # <pad> (1), or at the end if the row has no padding.
        modified_tensor = []
        for row in tensor:
            new_row = [0]  # start with <cls>
            ones_indices = (row == 1).nonzero(as_tuple=True)[0]
            if len(ones_indices) > 0:
                # Insert <eos> just before the first padding token.
                first_one_idx = ones_indices[0].item()
                new_row.extend(row[:first_one_idx].tolist())
                new_row.append(2)
                new_row.extend(row[first_one_idx:].tolist())
            else:
                # No padding in this row: append <eos> at the end.
                new_row.extend(row.tolist())
                new_row.append(2)
            modified_tensor.append(torch.tensor(new_row))
        return torch.stack(modified_tensor)
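
    # Worked example for add_cls_eos (hypothetical token rows; the ESM-2
    # alphabet uses <cls>=0, <pad>=1, <eos>=2):
    #   [5, 6, 7, 1]  ->  [0, 5, 6, 7, 2, 1]   (<eos> inserted before padding)
    #   [5, 6, 7, 8]  ->  [0, 5, 6, 7, 8, 2]   (no padding; <eos> appended)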

    def forward(self, tokens, esm, alphabet):
        batch_size, seq_len = tokens.size()
        pad_mask = (tokens != alphabet.padding_idx).int()  # shape: B*L
        device = tokens.device
        # ESM-2 embedding
        with torch.no_grad():
            esm_results = esm(tokens, repr_layers=[33], return_contacts=True)
        esm_embedding = esm_results["representations"][33]  # shape: B*L*1280
        esm_embedding = esm_embedding * pad_mask.unsqueeze(-1)
        # VHSE8 embedding
        vhse8_tensor = self.vhse8_tensor.to(device)
        vhse8_embedding = vhse8_tensor[tokens]
        # Sinusoidal positional embedding
        sin_embedding = self.create_sinusoidal_embeddings(seq_len, self.d_position).repeat(batch_size, 1, 1).to(device)  # shape: B*L*d_position
        sin_embedding = sin_embedding * pad_mask.unsqueeze(-1)
        # # One-hot position encoding
        # one_hot = torch.stack((self.one_hot_encoding(seq_len),) * batch_size)  # shape: B*L*L
        # one_hot_embedding = self.position_embedding(one_hot.view(-1, seq_len)).view(batch_size, seq_len, -1)  # shape: B*L*d_position
        # one_hot_embedding = one_hot_embedding * pad_mask.unsqueeze(-1)
        node_representation = torch.cat((esm_embedding, vhse8_embedding, sin_embedding), dim=-1)  # shape: B*L*(1280+8+d_position)
        node_representation = self.node_mapping(node_representation)  # shape: B*L*d_node
        # Edge representation: add <cls> and <eos> back to the tokens, since
        # ESM-2 contact prediction expects them.
        with torch.no_grad():
            esm_results = esm(self.add_cls_eos(tokens.cpu()).to(device), repr_layers=[33], return_contacts=True)
        contact_map = esm_results["contacts"]  # shape: B*L*L
        edge_representation = self.linear_edge(contact_map.unsqueeze(-1))  # shape: B*L*L*d_edge
        # Zero out edges pointing to padded positions.
        expanded_pad_mask = pad_mask.unsqueeze(1).expand(-1, seq_len, -1)
        edge_representation = edge_representation * expanded_pad_mask.unsqueeze(-1)
        # edge_representation = edge_representation * expanded_pad_mask.transpose(1, 2).unsqueeze(-1)
        return node_representation, edge_representation, pad_mask, expanded_pad_mask


if __name__ == '__main__':
    import esm  # fair-esm package (pip install fair-esm)

    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    tokens = torch.tensor([[5, 5, 5, 1], [5, 6, 7, 8]])
    graph = ProteinGraph(1024, 512, 64)
    # forward() returns four tensors, including the expanded pad mask.
    node, edge, pad, expanded_pad = graph(tokens, model, alphabet)
    print(node.shape, edge.shape, pad.shape, expanded_pad.shape)