import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.weight_norm import weight_norm
import math
import numpy as np
from sklearn.metrics import f1_score  # needed by find_opt_threshold

class cross_attn_block(nn.Module):
    def __init__(self, embed_dim, n_heads, dropout):
        super().__init__()
        self.heads = n_heads
        self.mha = nn.MultiheadAttention(embed_dim, n_heads, dropout, batch_first=True)
        self.ln_apt = nn.LayerNorm(embed_dim)
        self.ln_prot = nn.LayerNorm(embed_dim)
        self.ln_out = nn.LayerNorm(embed_dim)
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, embeddings_x, embeddings_y, x_t, y_t):
        # compute attention masks
        attn_mask = generate_3d_mask(y_t, x_t, self.heads)
        # apply layer norms
        embeddings_x_n = self.ln_apt(embeddings_x)
        embeddings_y_n = self.ln_prot(embeddings_y)
        # perform cross-attention
        reps = embeddings_y + self.mha(embeddings_y_n, embeddings_x_n, embeddings_x_n, attn_mask=attn_mask)[0]
        return reps + self.linear(self.ln_out(reps))
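
# Note on cross_attn_block: queries come from `embeddings_y` and keys/values from
# `embeddings_x`, so the "y" sequence attends over the "x" sequence. `x_t` and `y_t`
# appear to be boolean token masks (True marking real tokens); they are turned into
# an additive attention mask by generate_3d_mask below.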
class self_attn_block(nn.Module):
    def __init__(self, d_embed, heads, dropout):
        super().__init__()
        # self.l1 = nn.Linear(d_linear, d_linear)
        self.heads = heads
        self.ln1 = nn.LayerNorm(d_embed)
        self.ln2 = nn.LayerNorm(d_embed)
        self.mha = nn.MultiheadAttention(d_embed, self.heads, dropout, batch_first=True)
        self.linear = nn.Linear(d_embed, d_embed)

    def forward(self, embeddings_x, x_t):
        # compute attention masks
        # attn_mask = generate_3d_mask(x_t, x_t, self.heads)
        # apply layer norm
        embeddings_x_n = self.ln1(embeddings_x)
        reps = embeddings_x + self.mha(embeddings_x_n, embeddings_x_n, embeddings_x_n, key_padding_mask=~x_t)[0]
        return reps + self.linear(self.ln2(reps))
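
# Note on self_attn_block: `x_t` is expected to be a boolean mask with True for
# valid tokens; it is inverted (~x_t) because nn.MultiheadAttention's
# key_padding_mask marks the positions to ignore with True.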
class AptaBLE(nn.Module):
    def __init__(self, apta_encoder, prot_encoder, dropout):
        super(AptaBLE, self).__init__()
        # hyperparameters
        self.apta_encoder = apta_encoder
        self.prot_encoder = prot_encoder
        self.flatten = nn.Flatten()
        self.prot_reshape = nn.Linear(1280, 512)
        self.apta_keep = nn.Linear(512, 512)
        self.l1 = nn.Linear(1024, 1024)
        self.l2 = nn.Linear(1024, 512)
        self.l3 = nn.Linear(512, 256)
        self.l4 = nn.Linear(256, 1)
        self.can = CAN(512, 8, 1, 'mean_all_tok')
        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, apta_in, esm_prot, apta_attn, prot_attn):
        # output: (BS x #apt_toks x apt_embed_dim), encoder outputs (BS x MLM & sec. structure feature embeddings)
        apta = self.apta_encoder(apta_in, apta_attn, apta_attn, output_hidden_states=True)['hidden_states'][-1]
        prot = self.prot_encoder(esm_prot, repr_layers=[33], return_contacts=False)['representations'][33]
        prot = self.prot_reshape(prot)
        apta = self.apta_keep(apta)
        output, cross_map, prot_map, apta_map = self.can(prot, apta, prot_attn, apta_attn)
        output = self.relu(self.l1(output))
        output = self.bn1(output)
        output = self.relu(self.l2(output))
        output = self.bn2(output)
        output = self.relu(self.l3(output))
        output = self.bn3(output)
        output = self.l4(output)
        output = torch.sigmoid(output)
        return output, cross_map, prot_map, apta_map
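
# Note on AptaBLE.forward: the aptamer encoder is assumed to expose a Hugging
# Face-style interface (`output_hidden_states=True`, returning 'hidden_states'),
# and the protein encoder an ESM-style interface (`repr_layers=[33]`, returning
# layer-33 'representations' of width 1280, e.g. ESM-2 650M). Both towers are
# projected to 512 dims, fused by the cross-attention network (CAN), and scored
# by the MLP head with a sigmoid output.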
def find_opt_threshold(target, pred):
    result = 0
    best = 0
    for i in range(0, 1000):
        pred_threshold = np.where(pred > i/1000, 1, 0)
        now = f1_score(target, pred_threshold)
        if now > best:
            result = i/1000
            best = now
    return result
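
# Example (illustrative only): find_opt_threshold sweeps 1000 candidate thresholds
# in [0, 1) and returns the one maximizing sklearn's f1_score, e.g.
#   find_opt_threshold(np.array([0, 1, 1]), np.array([0.2, 0.8, 0.6]))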
def argument_seqset(seqset):
    arg_seqset = []
    for s, ss in seqset:
        arg_seqset.append([s, ss])
        arg_seqset.append([s[::-1], ss[::-1]])
    return arg_seqset
def augment_apis(apta, prot, ys):
    aug_apta = []
    aug_prot = []
    aug_y = []
    for a, p, y in zip(apta, prot, ys):
        aug_apta.append(a)
        aug_prot.append(p)
        aug_y.append(y)
        aug_apta.append(a[::-1])
        aug_prot.append(p)
        aug_y.append(y)
        aug_apta.append(a)
        aug_prot.append(p[::-1])
        aug_y.append(y)
        aug_apta.append(a[::-1])
        aug_prot.append(p[::-1])
        aug_y.append(y)
    return np.array(aug_apta), np.array(aug_prot), np.array(aug_y)
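
# Note on the augmentation helpers: argument_seqset mirrors each (sequence,
# secondary-structure) pair, while augment_apis expands every aptamer/protein
# pair into the four combinations of forward and reversed strings, all sharing
# the original label.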
def generate_3d_mask(batch1, batch2, heads):
    # Ensure the batches are boolean tensors
    batch1 = torch.as_tensor(batch1, dtype=torch.bool)
    batch2 = torch.as_tensor(batch2, dtype=torch.bool)
    # Validate that the batches have the same length
    if batch1.size(0) != batch2.size(0):
        raise ValueError("The batches must have the same number of vectors")
    # Generate the 3D mask for each pair of vectors (outer product of the token masks)
    out_mask = []
    masks = torch.stack([torch.outer(vec1, vec2) for vec1, vec2 in zip(batch1, batch2)])
    for j in range(masks.shape[0]):
        out_mask.append(torch.stack([masks[j] for i in range(heads)]))
    # out_mask = torch.tensor(out_mask, dtype=bool)
    out_mask = torch.cat(out_mask)
    # Replace False with a large negative value and True with 0
    out_mask = out_mask.float()  # Convert to float to allow additive masking
    out_mask[out_mask == 0] = -1e9
    out_mask[out_mask == 1] = 0
    return out_mask
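
# Note on generate_3d_mask: the returned additive mask has shape
# (batch * heads, L1, L2), with 0 where both tokens are valid and -1e9 elsewhere,
# matching the 3-D `attn_mask` layout expected by nn.MultiheadAttention.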
class CAN(nn.Module):
    def __init__(self, hidden_dim, num_heads, group_size, aggregation):
        super(CAN, self).__init__()
        self.aggregation = aggregation
        self.group_size = group_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Protein weights
        self.prot_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.prot_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.prot_val = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Aptamer weights
        self.apta_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.apta_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.apta_val = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # linear
        self.lp = nn.Linear(hidden_dim, hidden_dim)
    def mask_logits(self, logits, mask_row, mask_col, inf=1e6):
        N, L1, L2, H = logits.shape
        mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
        # Ignore all padding tokens across both embeddings
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)
        # Set logit to -1e6 if masked
        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha
    def rearrange_heads(self, x, n_heads, n_ch):
        # rearrange embedding for MHA
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)
    def grouped_embeddings(self, x, mask, group_size):
        N, L, D = x.shape
        groups = L // group_size
        # Average embeddings within each group
        x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
        # Ignore groups without any non-padding tokens
        mask_grouped = mask.view(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped
    def forward(self, protein, aptamer, mask_prot, mask_apta):
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.grouped_embeddings(protein, mask_prot, self.group_size)
        apta_grouped, mask_apta_grouped = self.grouped_embeddings(aptamer, mask_apta, self.group_size)
        # Compute queries, keys, values for both protein and aptamer after grouping
        query_prot = self.rearrange_heads(self.prot_query(protein_grouped), self.num_heads, self.head_dim)
        key_prot = self.rearrange_heads(self.prot_key(protein_grouped), self.num_heads, self.head_dim)
        value_prot = self.rearrange_heads(self.prot_val(protein_grouped), self.num_heads, self.head_dim)
        query_apta = self.rearrange_heads(self.apta_query(apta_grouped), self.num_heads, self.head_dim)
        key_apta = self.rearrange_heads(self.apta_key(apta_grouped), self.num_heads, self.head_dim)
        value_apta = self.rearrange_heads(self.apta_val(apta_grouped), self.num_heads, self.head_dim)
        # Compute attention scores
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pa = torch.einsum('blhd, bkhd->blkh', query_prot, key_apta)
        logits_ap = torch.einsum('blhd, bkhd->blkh', query_apta, key_prot)
        logits_aa = torch.einsum('blhd, bkhd->blkh', query_apta, key_apta)
        ml_pp = self.mask_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        ml_pa = self.mask_logits(logits_pa, mask_prot_grouped, mask_apta_grouped)
        ml_ap = self.mask_logits(logits_ap, mask_apta_grouped, mask_prot_grouped)
        ml_aa = self.mask_logits(logits_aa, mask_apta_grouped, mask_apta_grouped)
        # Combine heads, combine self-attended and cross-attended representations (via avg)
        prot_embedding = (torch.einsum('blkh, bkhd->blhd', ml_pp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', ml_pa, value_apta).flatten(-2)) / 2
        apta_embedding = (torch.einsum('blkh, bkhd->blhd', ml_ap, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', ml_aa, value_apta).flatten(-2)) / 2
        prot_embedding += protein
        apta_embedding += aptamer
        # Aggregate token representations
        if self.aggregation == "cls":
            prot_embed = prot_embedding[:, 0]    # [batch_size, hidden]
            apta_embed = apta_embedding[:, 0]    # [batch_size, hidden]
        elif self.aggregation == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # [batch_size, hidden]
            apta_embed = apta_embedding.mean(1)  # [batch_size, hidden]
        elif self.aggregation == "mean":
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
            apta_embed = (apta_embedding * mask_apta_grouped.unsqueeze(-1)).sum(1) / mask_apta_grouped.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        embed = torch.cat([prot_embed, apta_embed], dim=1)
        return embed, ml_pa, ml_pp, ml_aa
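
# Minimal smoke test (illustrative sketch only, not part of the training pipeline):
# exercises self_attn_block and find_opt_threshold with small arbitrary inputs.
if __name__ == "__main__":
    torch.manual_seed(0)
    blk = self_attn_block(d_embed=32, heads=4, dropout=0.0)
    x = torch.randn(2, 10, 32)
    x_t = torch.ones(2, 10, dtype=torch.bool)
    x_t[:, 8:] = False  # treat the last two positions as padding
    print(blk(x, x_t).shape)  # expected: torch.Size([2, 10, 32])

    y_true = np.array([0, 1, 1, 0, 1])
    y_prob = np.array([0.2, 0.7, 0.6, 0.4, 0.9])
    print(find_opt_threshold(y_true, y_prob))  # threshold in [0, 1) maximizing F1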