Atom Bioworks commited on
Commit
20c9cc4
1 Parent(s): 25f05fc

Create encoders.py

Browse files
Files changed (1) hide show
  1. encoders.py +264 -0
encoders.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.weight_norm import weight_norm
5
+ import math
6
+ import numpy as np
7
+
8
+ class cross_attn_block(nn.Module):
9
+ def __init__(self, embed_dim, n_heads, dropout):
10
+ super().__init__()
11
+ self.heads = n_heads
12
+ self.mha = nn.MultiheadAttention(embed_dim, n_heads, dropout, batch_first=True)
13
+ self.ln_apt = nn.LayerNorm(embed_dim)
14
+ self.ln_prot = nn.LayerNorm(embed_dim)
15
+ self.ln_out = nn.LayerNorm(embed_dim)
16
+ self.linear = nn.Linear(embed_dim, embed_dim)
17
+
18
+ def forward(self, embeddings_x, embeddings_y, x_t, y_t):
19
+
20
+ # compute attention masks
21
+ attn_mask = generate_3d_mask(y_t, x_t, self.heads)
22
+
23
+ # apply layer norms
24
+ embeddings_x_n = self.ln_apt(embeddings_x)
25
+ embeddings_y_n = self.ln_prot(embeddings_y)
26
+
27
+ # perform cross-attention
28
+ reps = embeddings_y + self.mha(embeddings_y_n, embeddings_x_n, embeddings_x_n, attn_mask=attn_mask)[0]
29
+ return reps + self.linear(self.ln_out(reps))
30
+
31
+ class self_attn_block(nn.Module):
32
+ def __init__(self, d_embed, heads, dropout):
33
+ super().__init__()
34
+ # self.l1 = nn.Linear(d_linear, d_linear)
35
+ self.heads = heads
36
+ self.ln1 = nn.LayerNorm(d_embed)
37
+ self.ln2 = nn.LayerNorm(d_embed)
38
+ self.mha = nn.MultiheadAttention(d_embed, self.heads, dropout, batch_first=True)
39
+ self.linear = nn.Linear(d_embed, d_embed)
40
+
41
+ def forward(self, embeddings_x, x_t):
42
+
43
+ # compute attention masks
44
+ # attn_mask = generate_3d_mask(x_t, x_t, self.heads)
45
+ # apply layer norm
46
+ embeddings_x_n = self.ln1(embeddings_x)
47
+ reps = embeddings_x + self.mha(embeddings_x_n, embeddings_x_n, embeddings_x_n, key_padding_mask=~x_t)[0]
48
+ return reps + self.linear(self.ln2(reps))
49
+
50
+
51
+ class AptaBLE(nn.Module):
52
+ def __init__(self, apta_encoder, prot_encoder, dropout):
53
+ super(AptaBLE, self).__init__()
54
+
55
+ #hyperparameters
56
+ self.apta_encoder = apta_encoder
57
+ self.prot_encoder = prot_encoder
58
+
59
+ self.flatten = nn.Flatten()
60
+ self.prot_reshape = nn.Linear(1280, 512)
61
+ self.apta_keep = nn.Linear(512, 512)
62
+
63
+ self.l1 = nn.Linear(1024, 1024)
64
+ self.l2 = nn.Linear(1024, 512)
65
+ self.l3 = nn.Linear(512, 256)
66
+ self.l4 = nn.Linear(256, 1)
67
+ self.can = CAN(512, 8, 1, 'mean_all_tok')
68
+ self.bn1 = nn.BatchNorm1d(1024)
69
+ self.bn2 = nn.BatchNorm1d(512)
70
+ self.bn3 = nn.BatchNorm1d(256)
71
+ self.relu = nn.ReLU()
72
+
73
+
74
+
75
+ def forward(self, apta_in, esm_prot, apta_attn, prot_attn):
76
+ apta = self.apta_encoder(apta_in, apta_attn, apta_attn, output_hidden_states=True)['hidden_states'][-1] # output: (BS X #apt_toks x apt_embed_dim), encoder outputs (BS x MLM & sec. structure feature embeddings)
77
+
78
+ prot = self.prot_encoder(esm_prot, repr_layers=[33], return_contacts=False)['representations'][33]
79
+
80
+ prot = self.prot_reshape(prot)
81
+ apta = self.apta_keep(apta)
82
+
83
+ output, cross_map, prot_map, apta_map = self.can(prot, apta, prot_attn, apta_attn)
84
+ output = self.relu(self.l1(output))
85
+ output = self.bn1(output)
86
+ output = self.relu(self.l2(output))
87
+ output = self.bn2(output)
88
+ output = self.relu(self.l3(output))
89
+ output = self.bn3(output)
90
+ output = self.l4(output)
91
+ output = torch.sigmoid(output)
92
+
93
+ return output, cross_map, prot_map, apta_map
94
+
95
+ def find_opt_threshold(target, pred):
96
+ result = 0
97
+ best = 0
98
+
99
+ for i in range(0, 1000):
100
+ pred_threshold = np.where(pred > i/1000, 1, 0)
101
+ now = f1_score(target, pred_threshold)
102
+ if now > best:
103
+ result = i/1000
104
+ best = now
105
+
106
+ return result
107
+
108
+ def argument_seqset(seqset):
109
+ arg_seqset = []
110
+ for s, ss in seqset:
111
+ arg_seqset.append([s, ss])
112
+
113
+ arg_seqset.append([s[::-1], ss[::-1]])
114
+
115
+ return arg_seqset
116
+
117
+ def augment_apis(apta, prot, ys):
118
+ aug_apta = []
119
+ aug_prot = []
120
+ aug_y = []
121
+ for a, p, y in zip(apta, prot, ys):
122
+ aug_apta.append(a)
123
+ aug_prot.append(p)
124
+ aug_y.append(y)
125
+
126
+ aug_apta.append(a[::-1])
127
+ aug_prot.append(p)
128
+ aug_y.append(y)
129
+
130
+ aug_apta.append(a)
131
+ aug_prot.append(p[::-1])
132
+ aug_y.append(y)
133
+
134
+ aug_apta.append(a[::-1])
135
+ aug_prot.append(p[::-1])
136
+ aug_y.append(y)
137
+
138
+ return np.array(aug_apta), np.array(aug_prot), np.array(aug_y)
139
+
140
+ def generate_3d_mask(batch1, batch2, heads):
141
+ # Ensure the batches are tensors
142
+ batch1 = torch.tensor(batch1, dtype=torch.bool)
143
+ batch2 = torch.tensor(batch2, dtype=torch.bool)
144
+
145
+ # Validate that the batches have the same length
146
+ if batch1.size(0) != batch2.size(0):
147
+ raise ValueError("The batches must have the same number of vectors")
148
+
149
+ # Generate the 3D mask for each pair of vectors
150
+ out_mask = []
151
+ masks = torch.stack([torch.ger(vec1, vec2) for vec1, vec2 in zip(batch1, batch2)])
152
+ for j in range(masks.shape[0]):
153
+ out_mask.append(torch.stack([masks[j] for i in range(heads)]))
154
+ # out_mask = torch.tensor(out_mask, dtype=bool)
155
+ out_mask = torch.cat(out_mask)
156
+
157
+ # Replace False with -inf and True with 0
158
+ out_mask = out_mask.float() # Convert to float to allow -inf
159
+ out_mask[out_mask == 0] = -1e9
160
+ out_mask[out_mask == 1] = 0
161
+
162
+ return out_mask
163
+
164
+ class CAN(nn.Module):
165
+ def __init__(self, hidden_dim, num_heads, group_size, aggregation):
166
+ super(CAN, self).__init__()
167
+ self.aggregation = aggregation
168
+ self.group_size = group_size
169
+ self.hidden_dim = hidden_dim
170
+ self.num_heads = num_heads
171
+ self.head_dim = hidden_dim // num_heads
172
+
173
+ # Protein weights
174
+ self.prot_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
175
+ self.prot_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
176
+ self.prot_val = nn.Linear(hidden_dim, hidden_dim, bias=False)
177
+
178
+ # Aptamer weights
179
+ self.apta_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
180
+ self.apta_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
181
+ self.apta_val = nn.Linear(hidden_dim, hidden_dim, bias=False)
182
+
183
+ # linear
184
+ self.lp = nn.Linear(hidden_dim, hidden_dim)
185
+
186
+ def mask_logits(self, logits, mask_row, mask_col, inf=1e6):
187
+ N, L1, L2, H = logits.shape
188
+ mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
189
+ mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
190
+
191
+ # Ignore all padding tokens across both embeddings
192
+ mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)
193
+
194
+ # Set logit to -1e6 if masked
195
+ logits = torch.where(mask_pair, logits, logits - inf)
196
+ alpha = torch.softmax(logits, dim=2)
197
+ mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
198
+ alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
199
+ return alpha
200
+
201
+ def rearrange_heads(self, x, n_heads, n_ch):
202
+ # rearrange embedding for MHA
203
+ s = list(x.size())[:-1] + [n_heads, n_ch]
204
+ return x.view(*s)
205
+
206
+ def grouped_embeddings(self, x, mask, group_size):
207
+ N, L, D = x.shape
208
+ groups = L // group_size
209
+ # Average embeddings within each group
210
+ x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
211
+ # Ignore groups without any non-padding tokens
212
+ mask_grouped = mask.view(N, groups, group_size).any(dim=2)
213
+ return x_grouped, mask_grouped
214
+
215
+ def forward(self, protein, aptamer, mask_prot, mask_apta):
216
+ # Group embeddings before applying multi-head attention
217
+ protein_grouped, mask_prot_grouped = self.grouped_embeddings(protein, mask_prot, self.group_size)
218
+ apta_grouped, mask_apta_grouped = self.grouped_embeddings(aptamer, mask_apta, self.group_size)
219
+
220
+ # Compute queries, keys, values for both protein and aptamer after grouping
221
+ query_prot = self.rearrange_heads(self.prot_query(protein_grouped), self.num_heads, self.head_dim)
222
+ key_prot = self.rearrange_heads(self.prot_key(protein_grouped), self.num_heads, self.head_dim)
223
+ value_prot = self.rearrange_heads(self.prot_val(protein_grouped), self.num_heads, self.head_dim)
224
+
225
+ query_apta = self.rearrange_heads(self.apta_query(apta_grouped), self.num_heads, self.head_dim)
226
+ key_apta = self.rearrange_heads(self.apta_key(apta_grouped), self.num_heads, self.head_dim)
227
+ value_apta = self.rearrange_heads(self.apta_val(apta_grouped), self.num_heads, self.head_dim)
228
+
229
+ # Compute attention scores
230
+ logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
231
+ logits_pa = torch.einsum('blhd, bkhd->blkh', query_prot, key_apta)
232
+ logits_ap = torch.einsum('blhd, bkhd->blkh', query_apta, key_prot)
233
+ logits_aa = torch.einsum('blhd, bkhd->blkh', query_apta, key_apta)
234
+
235
+ ml_pp = self.mask_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
236
+ ml_pa = self.mask_logits(logits_pa, mask_prot_grouped, mask_apta_grouped)
237
+ ml_ap = self.mask_logits(logits_ap, mask_apta_grouped, mask_prot_grouped)
238
+ ml_aa = self.mask_logits(logits_aa, mask_apta_grouped, mask_apta_grouped)
239
+
240
+ # Combine heads, combine self-attended and cross-attended representations (via avg)
241
+ prot_embedding = (torch.einsum('blkh, bkhd->blhd', ml_pp, value_prot).flatten(-2) +
242
+ torch.einsum('blkh, bkhd->blhd', ml_pa, value_apta).flatten(-2)) / 2
243
+ apta_embedding = (torch.einsum('blkh, bkhd->blhd', ml_ap, value_prot).flatten(-2) +
244
+ torch.einsum('blkh, bkhd->blhd', ml_aa, value_apta).flatten(-2)) / 2
245
+
246
+ prot_embedding += protein
247
+ apta_embedding += aptamer
248
+
249
+ # Aggregate token representations
250
+ if self.aggregation == "cls":
251
+ prot_embed = prot_embedding[:, 0] # query : [batch_size, hidden]
252
+ apta_embed = apta_embedding[:, 0] # query : [batch_size, hidden]
253
+ elif self.aggregation == "mean_all_tok":
254
+ prot_embed = prot_embedding.mean(1) # query : [batch_size, hidden]
255
+ apta_embed = apta_embedding.mean(1) # query : [batch_size, hidden]
256
+ elif self.aggregation == "mean":
257
+ prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
258
+ apta_embed = (apta_embedding * mask_apta_grouped.unsqueeze(-1)).sum(1) / mask_apta_grouped.sum(-1).unsqueeze(-1)
259
+ else:
260
+ raise NotImplementedError()
261
+
262
+ embed = torch.cat([prot_embed, apta_embed], dim=1)
263
+
264
+ return embed, ml_pa, ml_pp, ml_aa