Upload 31 files
- esm/__init__.py +12 -0
- esm/axial_attention.py +239 -0
- esm/constants.py +10 -0
- esm/data.py +493 -0
- esm/esmfold/v1/__init__.py +0 -0
- esm/esmfold/v1/categorical_mixture.py +43 -0
- esm/esmfold/v1/esmfold.py +364 -0
- esm/esmfold/v1/misc.py +309 -0
- esm/esmfold/v1/pretrained.py +181 -0
- esm/esmfold/v1/tri_self_attn_block.py +160 -0
- esm/esmfold/v1/trunk.py +243 -0
- esm/inverse_folding/__init__.py +8 -0
- esm/inverse_folding/features.py +352 -0
- esm/inverse_folding/gvp_encoder.py +56 -0
- esm/inverse_folding/gvp_modules.py +475 -0
- esm/inverse_folding/gvp_transformer.py +140 -0
- esm/inverse_folding/gvp_transformer_encoder.py +184 -0
- esm/inverse_folding/gvp_utils.py +68 -0
- esm/inverse_folding/multichain_util.py +152 -0
- esm/inverse_folding/transformer_decoder.py +228 -0
- esm/inverse_folding/transformer_layer.py +304 -0
- esm/inverse_folding/util.py +323 -0
- esm/model/__init__.py +1 -0
- esm/model/esm1.py +200 -0
- esm/model/esm2.py +147 -0
- esm/model/msa_transformer.py +238 -0
- esm/modules.py +418 -0
- esm/multihead_attention.py +508 -0
- esm/pretrained.py +552 -0
- esm/rotary_embedding.py +69 -0
- esm/version.py +6 -0
esm/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .version import version as __version__  # noqa

from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
from .model.esm1 import ProteinBertModel  # noqa
from .model.esm2 import ESM2  # noqa
from .model.msa_transformer import MSATransformer  # noqa
from . import pretrained  # noqa

esm/axial_attention.py
ADDED
@@ -0,0 +1,239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
import torch.nn as nn


class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        num_rows = q.size(0)
        return self.scaling / math.sqrt(num_rows)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        attns = 0
        scaling = self.align_scaling(x)
        for start in range(0, num_rows, max_rows):
            attn_weights = self.compute_attention_weights(
                x[start : start + max_rows],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
                if self_attn_padding_mask is not None
                else None,
            )
            attns += attn_weights
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)

        outputs = []
        for start in range(0, num_rows, max_rows):
            output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
            outputs.append(output)

        output = torch.cat(outputs, 0)
        return output, attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q *= scaling
        if self_attn_padding_mask is not None:
            # Zero out any padded aligned positions - this is important since
            # we take a sum across the alignment axis.
            q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            raise NotImplementedError
            # Mask Size: [B x R x C], Weights Size: [H x B x C x C]

        if self_attn_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
                -10000,
            )

        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        output = self.out_proj(context)
        return output

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
        else:
            scaling = self.align_scaling(x)
            attn_weights = self.compute_attention_weights(
                x, scaling, self_attn_mask, self_attn_padding_mask
            )
            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            output = self.compute_attention_update(x, attn_probs)
            return output, attn_probs


class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            output, attn = self(
                x[:, start : start + max_cols],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q *= self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)

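Note: the two modules above take MSA-shaped inputs of shape (num_rows, num_cols, batch_size, embed_dim) rather than the usual (batch, length, dim). A minimal shape sketch under that assumption (toy dimensions; this example is not part of the uploaded files):

    # Row/column axial attention over a small random "MSA" tensor.
    import torch
    from esm.axial_attention import RowSelfAttention, ColumnSelfAttention

    x = torch.randn(8, 64, 2, 96)      # 8 alignment rows, 64 columns, batch of 2, 96-dim embeddings
    row_attn = RowSelfAttention(embed_dim=96, num_heads=8)
    col_attn = ColumnSelfAttention(embed_dim=96, num_heads=8)

    row_out, row_probs = row_attn(x)   # row_out: (8, 64, 2, 96); row_probs: (heads, batch, cols, cols)
    col_out, col_probs = col_attn(x)   # col_out: (8, 64, 2, 96)
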
esm/constants.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# fmt: off
proteinseq_toks = {
    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
}
# fmt: on

esm/data.py
ADDED
@@ -0,0 +1,493 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
from typing import Sequence, Tuple, List, Union
import pickle
import re
import shutil
import torch
from pathlib import Path
from esm.constants import proteinseq_toks

RawMSA = Sequence[Tuple[str, str]]


class FastaBatchedDataset(object):
    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches


class Alphabet(object):
    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: int = None):
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls, name: str) -> "Alphabet":
        if name in ("ESM-1", "protein_bert_base"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        elif name in ("ESM-1b", "roberta_large"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = True
            use_msa = False
        elif name in ("MSA Transformer", "msa_transformer"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = False
            use_msa = True
        elif "invariant_gvp" in name.lower():
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>", "<cath>", "<af2>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        else:
            raise ValueError("Unknown architecture selected")
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)

    def _tokenize(self, text) -> str:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior depending on the special token
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                # We strip left and right by default
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.unique_no_split_tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.unique_no_split_tokens
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.unique_no_split_tokens
        tokenized_text = split_on_tokens(no_split_token, text)
        return tokenized_text

    def encode(self, text):
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]


class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return labels, strs, tokens


class MSABatchConverter(BatchConverter):
    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        if isinstance(inputs[0][0], str):
            # Input is a single MSA
            raw_batch: Sequence[RawMSA] = [inputs]  # type: ignore
        else:
            raw_batch = inputs  # type: ignore

        batch_size = len(raw_batch)
        max_alignments = max(len(msa) for msa in raw_batch)
        max_seqlen = max(len(msa[0][1]) for msa in raw_batch)

        tokens = torch.empty(
            (
                batch_size,
                max_alignments,
                max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, msa in enumerate(raw_batch):
            msa_seqlens = set(len(seq) for _, seq in msa)
            if not len(msa_seqlens) == 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens

        return labels, strs, tokens


def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    with open(path, "r") as f:
        for result in read_alignment_lines(
            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
        ):
            yield result


def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        # Line may be empty if seq % file_line_width == 0
        if len(line) > 0 and line[0] == ">":
            if seq is not None:
                yield desc, parse(seq)
            desc = line.strip().lstrip(">")
            seq = ""
        else:
            assert isinstance(seq, str)
            seq += line.strip()
    assert isinstance(seq, str) and isinstance(desc, str)
    yield desc, parse(seq)


class ESMStructuralSplitDataset(torch.utils.data.Dataset):
    """
    Structural Split Dataset as described in section A.10 of the supplement of our paper.
    https://doi.org/10.1101/622803

    We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
    generated on January 23, 2020.

    For each SCOPe domain:
        - We extract the sequence from the corresponding PDB file
        - We extract the 3D coordinates of the Carbon beta atoms, aligning them
          to the sequence. We put NaN where Cb atoms are missing.
        - From the 3D coordinates, we calculate a pairwise distance map, based
          on L2 distance
        - We use DSSP to generate secondary structure labels for the corresponding
          PDB file. This is also aligned to the sequence. We put - where SSP
          labels are missing.

    For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
    we have split the data into 5 partitions for cross validation. These are provided
    in a downloaded splits folder, in the format:
        splits/{split_level}/{cv_partition}/{train|valid}.txt
    where train is the partition and valid is the concatenation of the remaining 4.

    For each SCOPe domain, we provide a pkl dump that contains:
        - seq    : The domain sequence, stored as an L-length string
        - ssp    : The secondary structure labels, stored as an L-length string
        - dist   : The distance map, stored as an LxL numpy array
        - coords : The 3D coordinates, stored as an Lx3 numpy array

    """

    base_folder = "structural-data"
    file_list = [
        # url, tar filename, filename, MD5 hash
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
            "splits.tar.gz",
            "splits",
            "456fe1c7f22c9d3d8dfe9735da52411d",
        ),
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
            "pkl.tar.gz",
            "pkl",
            "644ea91e56066c750cd50101d390f5db",
        ),
    ]

    def __init__(
        self,
        split_level,
        cv_partition,
        split,
        root_path=os.path.expanduser("~/.cache/torch/data/esm"),
        download=False,
    ):
        super().__init__()
        assert split in [
            "train",
            "valid",
        ], "train_valid must be 'train' or 'valid'"
        self.root_path = root_path
        self.base_path = os.path.join(self.root_path, self.base_folder)

        # check if root path has what you need or else download it
        if download:
            self.download()

        self.split_file = os.path.join(
            self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
        )
        self.pkl_dir = os.path.join(self.base_path, "pkl")
        self.names = []
        with open(self.split_file) as f:
            self.names = f.read().splitlines()

    def __len__(self):
        return len(self.names)

    def _check_exists(self) -> bool:
        for (_, _, filename, _) in self.file_list:
            fpath = os.path.join(self.base_path, filename)
            if not os.path.exists(fpath) or not os.path.isdir(fpath):
                return False
        return True

    def download(self):

        if self._check_exists():
            print("Files already downloaded and verified")
            return

        from torchvision.datasets.utils import download_url

        for url, tar_filename, filename, md5_hash in self.file_list:
            download_path = os.path.join(self.base_path, tar_filename)
            download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
            shutil.unpack_archive(download_path, self.base_path)

    def __getitem__(self, idx):
        """
        Returns a dict with the following entries
            - seq    : Str (domain sequence)
            - ssp    : Str (SSP labels)
            - dist   : np.array (distance map)
            - coords : np.array (3D coordinates)
        """
        name = self.names[idx]
        pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
        with open(pkl_fname, "rb") as f:
            obj = pickle.load(f)
        return obj

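Since data.py provides the tokenization entry points used by the rest of the package, a minimal usage sketch may help (toy sequences; this example is not part of the uploaded files):

    # Build an ESM-1b-style alphabet and convert labelled sequences to padded token tensors.
    from esm.data import Alphabet

    alphabet = Alphabet.from_architecture("ESM-1b")
    batch_converter = alphabet.get_batch_converter()

    data = [
        ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQL"),
        ("protein2", "KALTARQQEVFDLIRD"),
    ]
    labels, strs, tokens = batch_converter(data)
    # tokens is (batch, max_len + bos + eos), padded with alphabet.padding_idx;
    # the ESM-1b alphabet prepends <cls> and appends <eos>.
    print(tokens.shape, bool(tokens[0, 0] == alphabet.cls_idx))
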
esm/esmfold/v1/__init__.py
ADDED
File without changes
esm/esmfold/v1/categorical_mixture.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch


class CategoricalMixture:
    def __init__(self, param, bins=50, start=0, end=1):
        # All tensors are of shape ..., bins.
        self.logits = param
        bins = torch.linspace(
            start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype
        )
        self.v_bins = (bins[:-1] + bins[1:]) / 2

    def log_prob(self, true):
        # Shapes are:
        #     self.probs: ... x bins
        #     true      : ...
        true_index = (
            (
                true.unsqueeze(-1)
                - self.v_bins[
                    [
                        None,
                    ]
                    * true.ndim
                ]
            )
            .abs()
            .argmin(-1)
        )
        nll = self.logits.log_softmax(-1)
        return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)


def categorical_lddt(logits, bins=50):
    # Logits are ..., 37, bins.
    return CategoricalMixture(logits, bins=bins).mean()

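A minimal sketch of how this module is used downstream (random logits; not part of the uploaded files): categorical_lddt converts per-residue, per-atom bin logits into expected lDDT values in (0, 1) by taking the probability-weighted mean of the bin centres.

    import torch
    from esm.esmfold.v1.categorical_mixture import categorical_lddt

    logits = torch.randn(2, 100, 37, 50)   # batch x residues x 37 atoms x 50 bins
    plddt = categorical_lddt(logits, bins=50)
    print(plddt.shape)                      # torch.Size([2, 100, 37]), values between 0 and 1
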
esm/esmfold/v1/esmfold.py
ADDED
@@ -0,0 +1,364 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import typing as T
from dataclasses import dataclass
from functools import partial

import torch
import torch.nn as nn
from torch import nn
from torch.nn import LayerNorm

import esm
from esm import Alphabet
from esm.esmfold.v1.categorical_mixture import categorical_lddt
from esm.esmfold.v1.misc import (
    batch_encode_sequences,
    collate_dense_tensors,
    output_to_pdb,
)
from esm.esmfold.v1.trunk import FoldingTrunk, FoldingTrunkConfig
from openfold.data.data_transforms import make_atom14_masks
from openfold.np import residue_constants
from openfold.utils.loss import compute_predicted_aligned_error, compute_tm


@dataclass
class ESMFoldConfig:
    trunk: T.Any = FoldingTrunkConfig()
    lddt_head_hid_dim: int = 128


load_fn = esm.pretrained.load_model_and_alphabet
esm_registry = {
    "esm2_8M": partial(load_fn, "esm2_t6_8M_UR50D_500K"),
    "esm2_8M_270K": esm.pretrained.esm2_t6_8M_UR50D,
    "esm2_35M": partial(load_fn, "esm2_t12_35M_UR50D_500K"),
    "esm2_35M_270K": esm.pretrained.esm2_t12_35M_UR50D,
    "esm2_150M": partial(load_fn, "esm2_t30_150M_UR50D_500K"),
    "esm2_150M_270K": partial(load_fn, "esm2_t30_150M_UR50D_270K"),
    "esm2_650M": esm.pretrained.esm2_t33_650M_UR50D,
    "esm2_650M_270K": partial(load_fn, "esm2_t33_650M_270K_UR50D"),
    "esm2_3B": esm.pretrained.esm2_t36_3B_UR50D,
    "esm2_3B_270K": partial(load_fn, "esm2_t36_3B_UR50D_500K"),
    "esm2_15B": esm.pretrained.esm2_t48_15B_UR50D,
}


class ESMFold(nn.Module):
    def __init__(self, esmfold_config=None, **kwargs):
        super().__init__()

        self.cfg = esmfold_config if esmfold_config else ESMFoldConfig(**kwargs)
        cfg = self.cfg

        self.distogram_bins = 64

        self.esm, self.esm_dict = esm_registry.get(cfg.esm_type)()

        self.esm.requires_grad_(False)
        self.esm.half()

        self.esm_feats = self.esm.embed_dim
        self.esm_attns = self.esm.num_layers * self.esm.attention_heads
        self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict))
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1))

        c_s = cfg.trunk.sequence_state_dim
        c_z = cfg.trunk.pairwise_state_dim

        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )
        if cfg.use_esm_attn_map:
            self.esm_z_mlp = nn.Sequential(
                LayerNorm(self.esm_attns),
                nn.Linear(self.esm_attns, c_z),
                nn.ReLU(),
                nn.Linear(c_z, c_z),
            )

        # 0 is padding, N is unknown residues, N + 1 is mask.
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1
        self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

        self.trunk = FoldingTrunk(**cfg.trunk)

        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
        self.lddt_bins = 50
        self.lddt_head = nn.Sequential(
            nn.LayerNorm(cfg.trunk.structure_module.c_s),
            nn.Linear(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim),
            nn.Linear(cfg.lddt_head_hid_dim, cfg.lddt_head_hid_dim),
            nn.Linear(cfg.lddt_head_hid_dim, 37 * self.lddt_bins),
        )

    @staticmethod
    def _af2_to_esm(d: Alphabet):
        # Remember that t is shifted from residue_constants by 1 (0 is padding).
        esm_reorder = [d.padding_idx] + [
            d.get_idx(v) for v in residue_constants.restypes_with_x
        ]
        return torch.tensor(esm_reorder)

    def _af2_idx_to_esm_idx(self, aa, mask):
        aa = (aa + 1).masked_fill(mask != 1, 0)
        return self.af2_to_esm[aa]

    def _compute_language_model_representations(
        self, esmaa: torch.Tensor
    ) -> torch.Tensor:
        """Adds bos/eos tokens for the language model, since the structure module doesn't use these."""
        batch_size = esmaa.size(0)

        bosi, eosi = self.esm_dict.cls_idx, self.esm_dict.eos_idx
        bos = esmaa.new_full((batch_size, 1), bosi)
        eos = esmaa.new_full((batch_size, 1), self.esm_dict.padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        # Use the first padding index as eos during inference.
        esmaa[range(batch_size), (esmaa != 1).sum(1)] = eosi

        res = self.esm(
            esmaa,
            repr_layers=range(self.esm.num_layers + 1),
            need_head_weights=self.cfg.use_esm_attn_map,
        )
        esm_s = torch.stack(
            [v for _, v in sorted(res["representations"].items())], dim=2
        )
        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C
        esm_z = (
            res["attentions"].permute(0, 4, 3, 1, 2).flatten(3, 4)[:, 1:-1, 1:-1, :]
            if self.cfg.use_esm_attn_map
            else None
        )
        return esm_s, esm_z

    def _mask_inputs_to_esm(self, esmaa, pattern):
        new_esmaa = esmaa.clone()
        new_esmaa[pattern == 1] = self.esm_dict.mask_idx
        return new_esmaa

    def forward(
        self,
        aa: torch.Tensor,
        mask: T.Optional[torch.Tensor] = None,
        residx: T.Optional[torch.Tensor] = None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = None,
    ):
        """Runs a forward pass given input tokens. Use `model.infer` to
        run inference from a sequence.

        Args:
            aa (torch.Tensor): Tensor containing indices corresponding to amino acids. Indices match
                openfold.np.residue_constants.restype_order_with_x.
            mask (torch.Tensor): Binary tensor with 1 meaning position is unmasked and 0 meaning position is masked.
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. If None, defaults to training max
                recycles, which is 3.
        """

        if mask is None:
            mask = torch.ones_like(aa)

        B = aa.shape[0]
        L = aa.shape[1]
        device = aa.device

        if residx is None:
            residx = torch.arange(L, device=device).expand_as(aa)

        # === ESM ===
        esmaa = self._af2_idx_to_esm_idx(aa, mask)

        if masking_pattern is not None:
            esmaa = self._mask_inputs_to_esm(esmaa, masking_pattern)

        esm_s, esm_z = self._compute_language_model_representations(esmaa)

        # Convert esm_s to the precision used by the trunk and
        # the structure module. These tensors may be a lower precision if, for example,
        # we're running the language model in fp16 precision.
        esm_s = esm_s.to(self.esm_s_combine.dtype)
        esm_s = esm_s.detach()

        # === preprocessing ===
        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)

        s_s_0 = self.esm_s_mlp(esm_s)
        if self.cfg.use_esm_attn_map:
            esm_z = esm_z.to(self.esm_s_combine.dtype)
            esm_z = esm_z.detach()
            s_z_0 = self.esm_z_mlp(esm_z)
        else:
            s_z_0 = s_s_0.new_zeros(B, L, L, self.cfg.trunk.pairwise_state_dim)

        s_s_0 += self.embedding(aa)

        structure: dict = self.trunk(
            s_s_0, s_z_0, aa, residx, mask, no_recycles=num_recycles
        )
        # Documenting what we expect:
        structure = {
            k: v
            for k, v in structure.items()
            if k
            in [
                "s_z",
                "s_s",
                "frames",
                "sidechain_frames",
                "unnormalized_angles",
                "angles",
                "positions",
                "states",
            ]
        }

        disto_logits = self.distogram_head(structure["s_z"])
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        lm_logits = self.lm_head(structure["s_s"])
        structure["lm_logits"] = lm_logits

        structure["aatype"] = aa
        make_atom14_masks(structure)

        for k in [
            "atom14_atom_exists",
            "atom37_atom_exists",
        ]:
            structure[k] *= mask.unsqueeze(-1)
        structure["residue_index"] = residx

        lddt_head = self.lddt_head(structure["states"]).reshape(
            structure["states"].shape[0], B, L, -1, self.lddt_bins
        )
        structure["lddt_head"] = lddt_head
        plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
        structure["plddt"] = (
            100 * plddt
        )  # we predict plDDT between 0 and 1, scale to be between 0 and 100.

        ptm_logits = self.ptm_head(structure["s_z"])

        seqlen = mask.type(torch.int64).sum(1)
        structure["ptm_logits"] = ptm_logits
        structure["ptm"] = torch.stack(
            [
                compute_tm(
                    batch_ptm_logits[None, :sl, :sl],
                    max_bins=31,
                    no_bins=self.distogram_bins,
                )
                for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
            ]
        )
        structure.update(
            compute_predicted_aligned_error(
                ptm_logits, max_bin=31, no_bins=self.distogram_bins
            )
        )

        return structure

    @torch.no_grad()
    def infer(
        self,
        sequences: T.Union[str, T.List[str]],
        residx=None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = None,
        residue_index_offset: T.Optional[int] = 512,
        chain_linker: T.Optional[str] = "G" * 25,
    ):
        """Runs a forward pass given input sequences.

        Args:
            sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in,
                each chain should be separated by a ':' token (e.g. "<chain1>:<chain2>:<chain3>").
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. If None, defaults to training max
                recycles (cfg.trunk.max_recycles), which is 4.
            residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on
                single chain predictions. Default: 512.
            chain_linker (str): Linker to use between chains if predicting a multimer. Has no effect on single chain
                predictions. Default: length-25 poly-G ("G" * 25).
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences(
            sequences, residue_index_offset, chain_linker
        )

        if residx is None:
            residx = _residx
        elif not isinstance(residx, torch.Tensor):
            residx = collate_dense_tensors(residx)

        aatype, mask, residx, linker_mask = map(
            lambda x: x.to(self.device), (aatype, mask, residx, linker_mask)
        )

        output = self.forward(
            aatype,
            mask=mask,
            residx=residx,
            masking_pattern=masking_pattern,
            num_recycles=num_recycles,
        )

        output["atom37_atom_exists"] = output[
            "atom37_atom_exists"
        ] * linker_mask.unsqueeze(2)

        output["mean_plddt"] = (output["plddt"] * output["atom37_atom_exists"]).sum(
            dim=(1, 2)
        ) / output["atom37_atom_exists"].sum(dim=(1, 2))
        output["chain_index"] = chain_index

        return output

    def output_to_pdb(self, output: T.Dict) -> T.List[str]:
        """Returns the pdb (file) string from the model given the model output."""
        return output_to_pdb(output)

    def infer_pdbs(self, seqs: T.List[str], *args, **kwargs) -> T.List[str]:
        """Returns a list of pdb (file) strings from the model given a list of input sequences."""
        output = self.infer(seqs, *args, **kwargs)
        return self.output_to_pdb(output)

    def infer_pdb(self, sequence: str, *args, **kwargs) -> str:
        """Returns the pdb (file) string from the model given an input sequence."""
        return self.infer_pdbs([sequence], *args, **kwargs)[0]

    def set_chunk_size(self, chunk_size: T.Optional[int]):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
        # where chunk_size is the size of the chunks, so 128 would mean to process 128-length chunks.
        # Setting the value to None restores the default behavior and disables chunking.
        self.trunk.set_chunk_size(chunk_size)

    @property
    def device(self):
        return self.esm_s_combine.device

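A hedged inference sketch for the module above (not part of the uploaded files). It assumes a loader such as esmfold_v1() is exposed by esm/esmfold/v1/pretrained.py (that file is in this upload, but its loader names are not shown here), and that openfold, a CUDA device, and the model weights are available.

    from esm.esmfold.v1.pretrained import esmfold_v1  # assumed loader name

    model = esmfold_v1().eval().cuda()
    model.set_chunk_size(128)            # chunk axial attention to trade speed for memory on long inputs

    sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
    pdb_str = model.infer_pdb(sequence)  # single-chain prediction as a PDB-format string
    with open("prediction.pdb", "w") as f:
        f.write(pdb_str)
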
esm/esmfold/v1/misc.py
ADDED
@@ -0,0 +1,309 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import typing as T

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from openfold.np import residue_constants
from openfold.np.protein import Protein as OFProtein
from openfold.np.protein import to_pdb
from openfold.utils.feats import atom14_to_atom37


def encode_sequence(
    seq: str,
    residue_index_offset: T.Optional[int] = 512,
    chain_linker: T.Optional[str] = "G" * 25,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    if chain_linker is None:
        chain_linker = ""
    if residue_index_offset is None:
        residue_index_offset = 0

    chains = seq.split(":")
    seq = chain_linker.join(chains)

    unk_idx = residue_constants.restype_order_with_x["X"]
    encoded = torch.tensor(
        [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq]
    )
    residx = torch.arange(len(encoded))

    if residue_index_offset > 0:
        start = 0
        for i, chain in enumerate(chains):
            residx[start : start + len(chain) + len(chain_linker)] += (
                i * residue_index_offset
            )
            start += len(chain) + len(chain_linker)

    linker_mask = torch.ones_like(encoded, dtype=torch.float32)
    chain_index = []
    offset = 0
    for i, chain in enumerate(chains):
        if i > 0:
            chain_index.extend([i - 1] * len(chain_linker))
        chain_index.extend([i] * len(chain))
        offset += len(chain)
        linker_mask[offset : offset + len(chain_linker)] = 0
        offset += len(chain_linker)

    chain_index = torch.tensor(chain_index, dtype=torch.int64)

    return encoded, residx, linker_mask, chain_index


def batch_encode_sequences(
    sequences: T.Sequence[str],
    residue_index_offset: T.Optional[int] = 512,
    chain_linker: T.Optional[str] = "G" * 25,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    aatype_list = []
    residx_list = []
    linker_mask_list = []
    chain_index_list = []
    for seq in sequences:
        aatype_seq, residx_seq, linker_mask_seq, chain_index_seq = encode_sequence(
            seq,
            residue_index_offset=residue_index_offset,
            chain_linker=chain_linker,
        )
        aatype_list.append(aatype_seq)
        residx_list.append(residx_seq)
        linker_mask_list.append(linker_mask_seq)
        chain_index_list.append(chain_index_seq)

    aatype = collate_dense_tensors(aatype_list)
    mask = collate_dense_tensors(
        [aatype.new_ones(len(aatype_seq)) for aatype_seq in aatype_list]
    )
    residx = collate_dense_tensors(residx_list)
    linker_mask = collate_dense_tensors(linker_mask_list)
    chain_index_list = collate_dense_tensors(chain_index_list, -1)

    return aatype, mask, residx, linker_mask, chain_index_list


def output_to_pdb(output: T.Dict) -> T.List[str]:
    """Returns the pdb (file) string from the model given the model output."""
    # atom14_to_atom37 must be called first, as it fails on latest numpy if the
    # input is a numpy array. It will work if the input is a torch tensor.
    final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
    output = {k: v.to("cpu").numpy() for k, v in output.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = output["atom37_atom_exists"]
    pdbs = []
    for i in range(output["aatype"].shape[0]):
        aa = output["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = output["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=output["plddt"][i],
            chain_index=output["chain_index"][i] if "chain_index" in output else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs


def collate_dense_tensors(
    samples: T.List[torch.Tensor], pad_v: float = 0
) -> torch.Tensor:
    """
    Takes a list of tensors with the following dimensions:
        [(d_11, ..., d_1K),
         (d_21, ..., d_2K),
         ...,
         (d_N1, ..., d_NK)]
    and stacks + pads them into a single tensor of:
    (N, max_i=1,N { d_i1 }, ..., max_i=1,N {d_iK})
    """
    if len(samples) == 0:
        return torch.Tensor()
    if len(set(x.dim() for x in samples)) != 1:
        raise RuntimeError(
            f"Samples has varying dimensions: {[x.dim() for x in samples]}"
        )
    (device,) = tuple(set(x.device for x in samples))  # assumes all on same device
    max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
    result = torch.empty(
        len(samples), *max_shape, dtype=samples[0].dtype, device=device
    )
    result.fill_(pad_v)
    for i in range(len(samples)):
        result_i = result[i]
        t = samples[i]
        result_i[tuple(slice(0, k) for k in t.shape)] = t
    return result


class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        Basic self attention with optional mask and external pairwise bias.
        To handle sequences of different lengths, use mask.

        Inputs:
            x: batch of input sequences (.. x L x C)
            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional.
            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional.

        Outputs:
            sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
        """

        t = rearrange(self.proj(x), "... l (h c) -> ... h l c", h=self.num_heads)
        q, k, v = t.chunk(3, dim=-1)

        q = self.rescale_factor * q
        a = torch.einsum("...qc,...kc->...qk", q, k)

        # Add external attention bias.
        if bias is not None:
            a = a + rearrange(bias, "... lq lk h -> ... h lq lk")

        # Do not attend to padding tokens.
        if mask is not None:
            mask = repeat(
                mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2]
            )
            a = a.masked_fill(mask == False, -np.inf)

        a = F.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads)

        if self.gated:
            y = self.g_proj(x).sigmoid() * y
        y = self.o_proj(y)

        return y, rearrange(a, "... lq lk h -> ... h lq lk")


class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.
    """

    def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]):
        super(Dropout, self).__init__()

        self.r = r
        if type(batch_dim) == int:
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        return x * self.dropout(x.new_ones(shape))

|
236 |
+
|
237 |
+
class SequenceToPair(nn.Module):
|
238 |
+
def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
|
239 |
+
super().__init__()
|
240 |
+
|
241 |
+
self.layernorm = nn.LayerNorm(sequence_state_dim)
|
242 |
+
self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
|
243 |
+
self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
|
244 |
+
|
245 |
+
torch.nn.init.zeros_(self.proj.bias)
|
246 |
+
torch.nn.init.zeros_(self.o_proj.bias)
|
247 |
+
|
248 |
+
def forward(self, sequence_state):
|
249 |
+
"""
|
250 |
+
Inputs:
|
251 |
+
sequence_state: B x L x sequence_state_dim
|
252 |
+
|
253 |
+
Output:
|
254 |
+
pairwise_state: B x L x L x pairwise_state_dim
|
255 |
+
|
256 |
+
Intermediate state:
|
257 |
+
B x L x L x 2*inner_dim
|
258 |
+
"""
|
259 |
+
|
260 |
+
assert len(sequence_state.shape) == 3
|
261 |
+
|
262 |
+
s = self.layernorm(sequence_state)
|
263 |
+
s = self.proj(s)
|
264 |
+
q, k = s.chunk(2, dim=-1)
|
265 |
+
|
266 |
+
prod = q[:, None, :, :] * k[:, :, None, :]
|
267 |
+
diff = q[:, None, :, :] - k[:, :, None, :]
|
268 |
+
|
269 |
+
x = torch.cat([prod, diff], dim=-1)
|
270 |
+
x = self.o_proj(x)
|
271 |
+
|
272 |
+
return x
|
273 |
+
|
274 |
+
|
275 |
+
class PairToSequence(nn.Module):
|
276 |
+
def __init__(self, pairwise_state_dim, num_heads):
|
277 |
+
super().__init__()
|
278 |
+
|
279 |
+
self.layernorm = nn.LayerNorm(pairwise_state_dim)
|
280 |
+
self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
|
281 |
+
|
282 |
+
def forward(self, pairwise_state):
|
283 |
+
"""
|
284 |
+
Inputs:
|
285 |
+
pairwise_state: B x L x L x pairwise_state_dim
|
286 |
+
|
287 |
+
Output:
|
288 |
+
pairwise_bias: B x L x L x num_heads
|
289 |
+
"""
|
290 |
+
assert len(pairwise_state.shape) == 4
|
291 |
+
z = self.layernorm(pairwise_state)
|
292 |
+
pairwise_bias = self.linear(z)
|
293 |
+
return pairwise_bias
|
294 |
+
|
295 |
+
|
296 |
+
class ResidueMLP(nn.Module):
|
297 |
+
def __init__(self, embed_dim, inner_dim, norm=nn.LayerNorm, dropout=0):
|
298 |
+
super().__init__()
|
299 |
+
|
300 |
+
self.mlp = nn.Sequential(
|
301 |
+
norm(embed_dim),
|
302 |
+
nn.Linear(embed_dim, inner_dim),
|
303 |
+
nn.ReLU(),
|
304 |
+
nn.Linear(inner_dim, embed_dim),
|
305 |
+
nn.Dropout(dropout),
|
306 |
+
)
|
307 |
+
|
308 |
+
def forward(self, x):
|
309 |
+
return x + self.mlp(x)
|
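The two batching helpers above are easiest to see with a small usage sketch. This is not part of the uploaded file; it assumes the esm package and its openfold dependency are importable, and the example sequences are made up.

# Usage sketch (not part of the upload): pad a ragged batch with
# collate_dense_tensors and encode multi-chain sequences.
import torch
from esm.esmfold.v1.misc import batch_encode_sequences, collate_dense_tensors

# Ragged 1-D tensors are stacked and right-padded with pad_v (default 0).
padded = collate_dense_tensors([torch.ones(3), torch.ones(5)], pad_v=0)
print(padded.shape)  # torch.Size([2, 5])

# Multimers are written with ':' between chains; a poly-G linker is inserted
# and zeroed in linker_mask, and later chains get their residue indices
# shifted by residue_index_offset (512 by default).
aatype, mask, residx, linker_mask, chain_index = batch_encode_sequences(
    ["MKTAYIAKQR", "MKTAYIAKQR:GSHMLEDP"]
)
print(aatype.shape, residx.max().item())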
esm/esmfold/v1/pretrained.py
ADDED
@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import torch

from esm.esmfold.v1.esmfold import ESMFold


def _load_model(model_name):
    if model_name.endswith(".pt"):  # local, treat as filepath
        model_path = Path(model_name)
        model_data = torch.load(str(model_path), map_location="cpu")
    else:  # load from hub
        url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
        model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")

    cfg = model_data["cfg"]["model"]
    model_state = model_data["model"]
    model = ESMFold(esmfold_config=cfg)

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    missing_essential_keys = []
    for missing_key in expected_keys - found_keys:
        if not missing_key.startswith("esm."):
            missing_essential_keys.append(missing_key)

    if missing_essential_keys:
        raise RuntimeError(f"Keys '{', '.join(missing_essential_keys)}' are missing.")

    model.load_state_dict(model_state, strict=False)

    return model


def esmfold_v0():
    """
    ESMFold v0 model with 3B ESM-2, 48 folding blocks.
    This version was used for the paper (Lin et al, 2022). It was trained
    on all PDB chains until 2020-05, to ensure temporal holdout with CASP14
    and the CAMEO validation and test set reported there.
    """
    return _load_model("esmfold_3B_v0")


def esmfold_v1():
    """
    ESMFold v1 model using 3B ESM-2, 48 folding blocks.
    ESMFold provides fast high accuracy atomic level structure prediction
    directly from the individual sequence of a protein. ESMFold uses the ESM2
    protein language model to extract meaningful representations from the
    protein sequence.
    """
    return _load_model("esmfold_3B_v1")


def esmfold_structure_module_only_8M():
    """
    ESMFold baseline model using 8M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_8M")


def esmfold_structure_module_only_8M_270K():
    """
    ESMFold baseline model using 8M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_8M_270K")


def esmfold_structure_module_only_35M():
    """
    ESMFold baseline model using 35M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_35M")


def esmfold_structure_module_only_35M_270K():
    """
    ESMFold baseline model using 35M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_35M_270K")


def esmfold_structure_module_only_150M():
    """
    ESMFold baseline model using 150M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_150M")


def esmfold_structure_module_only_150M_270K():
    """
    ESMFold baseline model using 150M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_150M_270K")


def esmfold_structure_module_only_650M():
    """
    ESMFold baseline model using 650M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_650M")


def esmfold_structure_module_only_650M_270K():
    """
    ESMFold baseline model using 650M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_650M_270K")


def esmfold_structure_module_only_3B():
    """
    ESMFold baseline model using 3B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_3B")


def esmfold_structure_module_only_3B_270K():
    """
    ESMFold baseline model using 3B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_3B_270K")


def esmfold_structure_module_only_15B():
    """
    ESMFold baseline model using 15B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    The 15B parameter ESM-2 was not trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    return _load_model("esmfold_structure_module_only_15B")
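A minimal usage sketch for these loaders (not part of the uploaded file): esmfold_v1() resolves the checkpoint name to a dl.fbaipublicfiles.com URL via torch.hub unless the name ends in ".pt", in which case _load_model reads it as a local path. Running it assumes openfold and the rest of the esm package are installed, and it downloads the checkpoint on first use.

# Usage sketch (not part of the upload).
from esm.esmfold.v1 import pretrained as esmfold_pretrained

model = esmfold_pretrained.esmfold_v1()
model = model.eval()  # inference only

# A local checkpoint loads the same way, because _load_model treats any
# name ending in ".pt" as a file path:
# model = esmfold_pretrained._load_model("/path/to/esmfold_3B_v1.pt")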
esm/esmfold/v1/tri_self_attn_block.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from openfold.model.triangular_attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)
from torch import nn

from esm.esmfold.v1.misc import (
    Attention,
    Dropout,
    PairToSequence,
    ResidueMLP,
    SequenceToPair,
)


class TriangularSelfAttentionBlock(nn.Module):
    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        super().__init__()

        assert sequence_state_dim % sequence_head_width == 0
        assert pairwise_state_dim % pairwise_head_width == 0
        sequence_num_heads = sequence_state_dim // sequence_head_width
        pairwise_num_heads = pairwise_state_dim // pairwise_head_width
        assert sequence_state_dim == sequence_num_heads * sequence_head_width
        assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_att_start = TriangleAttentionStartingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore
        self.tri_att_end = TriangleAttentionEndingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore

        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)

        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        assert len(sequence_state.shape) == 3
        assert len(pairwise_state.shape) == 4
        if mask is not None:
            assert len(mask.shape) == 2

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pairwise_state.shape[0]
        assert seq_dim == pairwise_state.shape[1]
        assert seq_dim == pairwise_state.shape[2]

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_mul_out(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_mul_in(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
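One detail worth calling out from forward() above is how the B x L validity mask becomes the B x L x L mask handed to the triangular updates. A standalone sketch of just that mask algebra, with made-up sizes (not part of the uploaded file):

# Illustrative only: the pair mask is the outer product of the residue mask
# with itself, exactly as in tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1).
import torch

B, L = 2, 5
mask = torch.zeros(B, L, dtype=torch.bool)
mask[0, :5] = True   # first sequence uses all 5 positions
mask[1, :3] = True   # second sequence is padded after 3 residues

tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1)   # B x L x L
print(tri_mask.shape, tri_mask[1].sum().item())    # torch.Size([2, 5, 5]) 9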
esm/esmfold/v1/trunk.py
ADDED
@@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import typing as T
from contextlib import ExitStack
from dataclasses import dataclass

import torch
import torch.nn as nn
from openfold.model.structure_module import StructureModule

from esm.esmfold.v1.tri_self_attn_block import TriangularSelfAttentionBlock


@dataclass
class StructureModuleConfig:
    c_s: int = 384
    c_z: int = 128
    c_ipa: int = 16
    c_resnet: int = 128
    no_heads_ipa: int = 12
    no_qk_points: int = 4
    no_v_points: int = 8
    dropout_rate: float = 0.1
    no_blocks: int = 8
    no_transition_layers: int = 1
    no_resnet_blocks: int = 2
    no_angles: int = 7
    trans_scale_factor: int = 10
    epsilon: float = 1e-8
    inf: float = 1e5


@dataclass
class FoldingTrunkConfig:
    _name: str = "FoldingTrunkConfig"
    num_blocks: int = 48
    sequence_state_dim: int = 1024
    pairwise_state_dim: int = 128
    sequence_head_width: int = 32
    pairwise_head_width: int = 32
    position_bins: int = 32
    dropout: float = 0
    layer_drop: float = 0
    cpu_grad_checkpoint: bool = False

    max_recycles: int = 4
    chunk_size: T.Optional[int] = None

    structure_module: StructureModuleConfig = StructureModuleConfig()


def get_axial_mask(mask):
    """
    Helper to convert B x L mask of valid positions to the axial mask used
    in row/column attentions.

    Input:
      mask: B x L tensor of booleans

    Output:
      mask: B x L x L tensor of booleans
    """

    if mask is None:
        return None
    assert len(mask.shape) == 2
    batch_dim, seq_dim = mask.shape
    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
    m = m.reshape(batch_dim * seq_dim, seq_dim)
    return m


class RelativePosition(nn.Module):
    def __init__(self, bins, pairwise_state_dim):
        super().__init__()
        self.bins = bins

        # Note an additional offset is used so that the 0th position
        # is reserved for masked pairs.
        self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
          residue_index: B x L tensor of indices (dtype=torch.long)
          mask: B x L tensor of booleans

        Output:
          pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """

        assert residue_index.dtype == torch.long
        if mask is not None:
            assert residue_index.shape == mask.shape

        diff = residue_index[:, None, :] - residue_index[:, :, None]
        diff = diff.clamp(-self.bins, self.bins)
        diff = diff + self.bins + 1  # Add 1 to adjust for padding index.

        if mask is not None:
            mask = mask[:, None, :] * mask[:, :, None]
            diff[mask == False] = 0

        output = self.embedding(diff)
        return output


class FoldingTrunk(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.cfg = FoldingTrunkConfig(**kwargs)
        assert self.cfg.max_recycles > 0

        c_s = self.cfg.sequence_state_dim
        c_z = self.cfg.pairwise_state_dim

        assert c_s % self.cfg.sequence_head_width == 0
        assert c_z % self.cfg.pairwise_head_width == 0
        block = TriangularSelfAttentionBlock

        self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)

        self.blocks = nn.ModuleList(
            [
                block(
                    sequence_state_dim=c_s,
                    pairwise_state_dim=c_z,
                    sequence_head_width=self.cfg.sequence_head_width,
                    pairwise_head_width=self.cfg.pairwise_head_width,
                    dropout=self.cfg.dropout,
                )
                for i in range(self.cfg.num_blocks)
            ]
        )

        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        self.recycle_disto.weight[0].detach().zero_()

        self.structure_module = StructureModule(**self.cfg.structure_module)  # type: ignore
        self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
        self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)

        self.chunk_size = self.cfg.chunk_size

    def set_chunk_size(self, chunk_size):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
        """
        Inputs:
          seq_feats: B x L x C tensor of sequence features
          pair_feats: B x L x L x C tensor of pair features
          residx: B x L long tensor giving the position in the sequence
          mask: B x L boolean tensor indicating valid residues

        Output:
          predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.cfg.max_recycles
        else:
            assert no_recycles >= 0, "Number of recycles must not be negative."
            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

        def trunk_iter(s, z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        assert no_recycles > 0
        for recycle_idx in range(no_recycles):
            with ExitStack() if recycle_idx == no_recycles - 1 else torch.no_grad():
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach())
                recycle_z = self.recycle_z_norm(recycle_z.detach())
                recycle_z += self.recycle_disto(recycle_bins.detach())

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
                recycle_bins = FoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        assert isinstance(structure, dict)  # type: ignore
        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        # Infer CB coordinates.
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)  # [..., L, L]
        return bins
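The recycling distogram above depends only on torch, so its binning step can be sketched standalone. The sketch below mirrors the logic of distogram() with made-up CB coordinates, so it runs without openfold installed; it is not part of the uploaded file.

# Illustrative sketch of FoldingTrunk.distogram's binning step: squared
# distance boundaries from 3.375 to 21.375 Angstroms, with torch.sum counting
# how many boundaries each pairwise distance exceeds.
import torch

num_bins = 15                      # matches self.recycle_bins
boundaries = torch.linspace(3.375, 21.375, num_bins - 1) ** 2

CB = torch.randn(1, 10, 3) * 10.0  # fake CB coordinates, B x L x 3
dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdim=True)
bins = torch.sum(dists > boundaries, dim=-1)  # B x L x L integers in [0, 14]
print(bins.shape, bins.min().item(), bins.max().item())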
esm/inverse_folding/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import gvp_transformer
from . import util
from . import multichain_util
esm/inverse_folding/features.py
ADDED
@@ -0,0 +1,352 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Portions of this file were adapted from the open source code for the following
# two papers:
#
# Ingraham, J., Garg, V., Barzilay, R., & Jaakkola, T. (2019). Generative
# models for graph-based protein design. Advances in Neural Information
# Processing Systems, 32.
#
# Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
# Learning from Protein Structure with Geometric Vector Perceptrons. In
# International Conference on Learning Representations.
#
# MIT License
#
# Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# ================================================================
# The below license applies to the portions of the code (parts of
# src/datasets.py and src/models.py) adapted from Ingraham, et al.
# ================================================================
#
# MIT License
#
# Copyright (c) 2019 John Ingraham, Vikas Garg, Regina Barzilay, Tommi Jaakkola
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gvp_utils import flatten_graph
from .gvp_modules import GVP, LayerNorm
from .util import normalize, norm, nan_to_num, rbf


class GVPInputFeaturizer(nn.Module):

    @staticmethod
    def get_node_features(coords, coord_mask, with_coord_mask=True):
        # scalar features
        node_scalar_features = GVPInputFeaturizer._dihedrals(coords)
        if with_coord_mask:
            node_scalar_features = torch.cat([
                node_scalar_features,
                coord_mask.float().unsqueeze(-1)
            ], dim=-1)
        # vector features
        X_ca = coords[:, :, 1]
        orientations = GVPInputFeaturizer._orientations(X_ca)
        sidechains = GVPInputFeaturizer._sidechains(coords)
        node_vector_features = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
        return node_scalar_features, node_vector_features

    @staticmethod
    def _orientations(X):
        forward = normalize(X[:, 1:] - X[:, :-1])
        backward = normalize(X[:, :-1] - X[:, 1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    @staticmethod
    def _sidechains(X):
        n, origin, c = X[:, :, 0], X[:, :, 1], X[:, :, 2]
        c, n = normalize(c - origin), normalize(n - origin)
        bisector = normalize(c + n)
        perp = normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        X = torch.flatten(X[:, :, :3], 1, 2)
        bsz = X.shape[0]
        dX = X[:, 1:] - X[:, :-1]
        U = normalize(dX, dim=-1)
        u_2 = U[:, :-2]
        u_1 = U[:, 1:-1]
        u_0 = U[:, 2:]

        # Backbone normals
        n_2 = normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [bsz, -1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], -1)
        return D_features

    @staticmethod
    def _positional_embeddings(edge_index,
                               num_embeddings=None,
                               num_positional_embeddings=16,
                               period_range=[2, 1000]):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or num_positional_embeddings
        d = edge_index[0] - edge_index[1]

        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32,
                         device=edge_index.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    @staticmethod
    def _dist(X, coord_mask, padding_mask, top_k_neighbors, eps=1e-8):
        """ Pairwise euclidean distances """
        bsz, maxlen = X.size(0), X.size(1)
        coord_mask_2D = torch.unsqueeze(coord_mask, 1) * torch.unsqueeze(coord_mask, 2)
        residue_mask = ~padding_mask
        residue_mask_2D = torch.unsqueeze(residue_mask, 1) * torch.unsqueeze(residue_mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = coord_mask_2D * norm(dX, dim=-1)

        # sorting preference: first those with coords, then among the residues that
        # exist but are masked use distance in sequence as tie breaker, and then the
        # residues that came from padding are last
        seqpos = torch.arange(maxlen, device=X.device)
        Dseq = torch.abs(seqpos.unsqueeze(1) - seqpos.unsqueeze(0)).repeat(bsz, 1, 1)
        D_adjust = nan_to_num(D) + (~coord_mask_2D) * (1e8 + Dseq * 1e6) + (
            ~residue_mask_2D) * (1e10)

        if top_k_neighbors == -1:
            D_neighbors = D_adjust
            E_idx = seqpos.repeat(
                *D_neighbors.shape[:-1], 1)
        else:
            # Identify k nearest neighbors (including self)
            k = min(top_k_neighbors, X.size(1))
            D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)

        coord_mask_neighbors = (D_neighbors < 5e7)
        residue_mask_neighbors = (D_neighbors < 5e9)
        return D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors


class Normalize(nn.Module):
    def __init__(self, features, epsilon=1e-6):
        super(Normalize, self).__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        mu = x.mean(dim, keepdim=True)
        sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        gain = self.gain
        bias = self.bias
        # Reshape
        if dim != -1:
            shape = [1] * len(mu.size())
            shape[dim] = self.gain.size()[0]
            gain = gain.view(shape)
            bias = bias.view(shape)
        return gain * (x - mu) / (sigma + self.epsilon) + bias


class DihedralFeatures(nn.Module):
    def __init__(self, node_embed_dim):
        """ Embed dihedral angle features. """
        super(DihedralFeatures, self).__init__()
        # 3 dihedral angles; sin and cos of each angle
        node_in = 6
        # Normalization and embedding
        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """ Featurize coordinates as an attributed graph """
        V = self._dihedrals(X)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7, return_angles=False):
        # First 3 coordinates are N, CA, C
        X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)

        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view((D.size(0), int(D.size(1) / 3), 3))
        phi, psi, omega = torch.unbind(D, -1)

        if return_angles:
            return phi, psi, omega

        # Lift angle representations to the circle
        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        return D_features


class GVPGraphEmbedding(GVPInputFeaturizer):

    def __init__(self, args):
        super().__init__()
        self.top_k_neighbors = args.top_k_neighbors
        self.num_positional_embeddings = 16
        self.remove_edges_without_coords = True
        node_input_dim = (7, 3)
        edge_input_dim = (34, 1)
        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)
        self.embed_node = nn.Sequential(
            GVP(node_input_dim, node_hidden_dim, activations=(None, None)),
            LayerNorm(node_hidden_dim, eps=1e-4)
        )
        self.embed_edge = nn.Sequential(
            GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)),
            LayerNorm(edge_hidden_dim, eps=1e-4)
        )
        self.embed_confidence = nn.Linear(16, args.node_hidden_dim_scalar)

    def forward(self, coords, coord_mask, padding_mask, confidence):
        with torch.no_grad():
            node_features = self.get_node_features(coords, coord_mask)
            edge_features, edge_index = self.get_edge_features(
                coords, coord_mask, padding_mask)
        node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features)
        edge_embeddings = self.embed_edge(edge_features)

        rbf_rep = rbf(confidence, 0., 1.)
        node_embeddings = (
            node_embeddings_scalar + self.embed_confidence(rbf_rep),
            node_embeddings_vector
        )

        node_embeddings, edge_embeddings, edge_index = flatten_graph(
            node_embeddings, edge_embeddings, edge_index)
        return node_embeddings, edge_embeddings, edge_index

    def get_edge_features(self, coords, coord_mask, padding_mask):
        X_ca = coords[:, :, 1]
        # Get distances to the top k neighbors
        E_dist, E_idx, E_coord_mask, E_residue_mask = GVPInputFeaturizer._dist(
            X_ca, coord_mask, padding_mask, self.top_k_neighbors)
        # Flatten the graph to be batch size 1 for torch_geometric package
        dest = E_idx
        B, L, k = E_idx.shape[:3]
        src = torch.arange(L, device=E_idx.device).view([1, L, 1]).expand(B, L, k)
        # After flattening, [2, B, E]
        edge_index = torch.stack([src, dest], dim=0).flatten(2, 3)
        # After flattening, [B, E]
        E_dist = E_dist.flatten(1, 2)
        E_coord_mask = E_coord_mask.flatten(1, 2).unsqueeze(-1)
        E_residue_mask = E_residue_mask.flatten(1, 2)
        # Calculate relative positional embeddings and distance RBF
        pos_embeddings = GVPInputFeaturizer._positional_embeddings(
            edge_index,
            num_positional_embeddings=self.num_positional_embeddings,
        )
        D_rbf = rbf(E_dist, 0., 20.)
        # Calculate relative orientation
        X_src = X_ca.unsqueeze(2).expand(-1, -1, k, -1).flatten(1, 2)
        X_dest = torch.gather(
            X_ca,
            1,
            edge_index[1, :, :].unsqueeze(-1).expand([B, L*k, 3])
        )
        coord_mask_src = coord_mask.unsqueeze(2).expand(-1, -1, k).flatten(1, 2)
        coord_mask_dest = torch.gather(
            coord_mask,
            1,
            edge_index[1, :, :].expand([B, L*k])
        )
        E_vectors = X_src - X_dest
        # For the ones without coordinates, substitute in the average vector
        E_vector_mean = torch.sum(E_vectors * E_coord_mask, dim=1,
                keepdims=True) / torch.sum(E_coord_mask, dim=1, keepdims=True)
        E_vectors = E_vectors * E_coord_mask + E_vector_mean * ~(E_coord_mask)
        # Normalize and remove nans
        edge_s = torch.cat([D_rbf, pos_embeddings], dim=-1)
        edge_v = normalize(E_vectors).unsqueeze(-2)
        edge_s, edge_v = map(nan_to_num, (edge_s, edge_v))
        # Also add indications of whether the coordinates are present
        edge_s = torch.cat([
            edge_s,
            (~coord_mask_src).float().unsqueeze(-1),
            (~coord_mask_dest).float().unsqueeze(-1),
        ], dim=-1)
        edge_index[:, ~E_residue_mask] = -1
        if self.remove_edges_without_coords:
            edge_index[:, ~E_coord_mask.squeeze(-1)] = -1
        return (edge_s, edge_v), edge_index.transpose(0, 1)
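A standalone sketch of the k-nearest-neighbour selection that _dist() above performs, with made-up CA coordinates and no coordinate/padding masking (the masking terms only push residues without coordinates and padding positions to the end of the ranking before the same torch.topk call). Not part of the uploaded file.

# Illustrative only: pairwise CA distances followed by top-k selection.
import torch

B, L, k = 1, 8, 4
X_ca = torch.randn(B, L, 3)                      # fake CA coordinates
dX = X_ca.unsqueeze(1) - X_ca.unsqueeze(2)       # B x L x L x 3
D = dX.norm(dim=-1)                              # pairwise distances
D_neighbors, E_idx = torch.topk(D, k, dim=-1, largest=False)  # self included
print(E_idx.shape)            # torch.Size([1, 8, 4])
print(E_idx[0, 0, 0].item())  # 0: each residue's nearest neighbour is itself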
esm/inverse_folding/gvp_encoder.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

from .features import GVPGraphEmbedding
from .gvp_modules import GVPConvLayer, LayerNorm
from .gvp_utils import unflatten_graph


class GVPEncoder(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_graph = GVPGraphEmbedding(args)

        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)

        conv_activations = (F.relu, torch.sigmoid)
        self.encoder_layers = nn.ModuleList(
                GVPConvLayer(
                    node_hidden_dim,
                    edge_hidden_dim,
                    drop_rate=args.dropout,
                    vector_gate=True,
                    attention_heads=0,
                    n_message=3,
                    conv_activations=conv_activations,
                    n_edge_gvps=0,
                    eps=1e-4,
                    layernorm=True,
                )
            for i in range(args.num_encoder_layers)
        )

    def forward(self, coords, coord_mask, padding_mask, confidence):
        node_embeddings, edge_embeddings, edge_index = self.embed_graph(
            coords, coord_mask, padding_mask, confidence)

        for i, layer in enumerate(self.encoder_layers):
            node_embeddings, edge_embeddings = layer(node_embeddings,
                    edge_index, edge_embeddings)

        node_embeddings = unflatten_graph(node_embeddings, coords.shape[0])
        return node_embeddings
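GVPEncoder reads its hyperparameters off an argparse-style Namespace. A hedged construction sketch follows (not part of the upload): the field names are exactly the ones consumed above, but the values are illustrative rather than the published GVP-Transformer defaults, and torch_geometric must be installed for the import to succeed.

# Hedged sketch: construct the encoder from an ad-hoc Namespace.
from argparse import Namespace
from esm.inverse_folding.gvp_encoder import GVPEncoder

args = Namespace(
    top_k_neighbors=30,          # illustrative values, not the shipped defaults
    node_hidden_dim_scalar=100,
    node_hidden_dim_vector=16,
    edge_hidden_dim_scalar=32,
    edge_hidden_dim_vector=1,
    dropout=0.1,
    num_encoder_layers=2,
)
encoder = GVPEncoder(args)
# forward(coords, coord_mask, padding_mask, confidence) expects backbone
# coords of shape B x L x 3 x 3 (N, CA, C atoms), boolean masks of shape
# B x L, and a B x L confidence track; it returns per-node embeddings.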
esm/inverse_folding/gvp_modules.py
ADDED
@@ -0,0 +1,475 @@
# Contents of this file are from the open source code for
#
# Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
# Learning from Protein Structure with Geometric Vector Perceptrons. In
# International Conference on Learning Representations.
#
# MIT License
#
# Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import typing as T
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

def tuple_size(tp):
    return tuple([0 if a is None else a.size() for a in tp])

def tuple_sum(tp1, tp2):
    s1, v1 = tp1
    s2, v2 = tp2
    if v1 is None and v2 is None:
        return (s1 + s2, None)
    return (s1 + s2, v1 + v2)

def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of tuples (s, V) elementwise.

    :param dim: dimension along which to concatenate when viewed
                as the `dim` index for the scalar-channel tensors.
                This means that `dim=-1` will be applied as
                `dim=-2` for the vector-channel tensors.
    '''
    dim %= len(args[0][0].shape)
    s_args, v_args = list(zip(*args))
    return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)

def tuple_index(x, idx):
    '''
    Indexes into a tuple (s, V) along the first dimension.

    :param idx: any object which can be used to index into a `torch.Tensor`
    '''
    return x[0][idx], x[1][idx]

def randn(n, dims, device="cpu"):
    '''
    Returns random tuples (s, V) drawn elementwise from a normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    return torch.randn(n, dims[0], device=device), \
           torch.randn(n, dims[1], 3, device=device)

def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of tensor clamped above a minimum value `eps`.

    :param sqrt: if `False`, returns the square of the L2 norm
    '''
    # clamp is slow
    # out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
    out = torch.sum(torch.square(x), axis, keepdims) + eps
    return torch.sqrt(out) if sqrt else out

def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple.
    Should be used only with `_merge(s, V)` and only if the tuple
    representation cannot be used.

    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    '''
    v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
    s = x[..., :-3*nv]
    return s, v

def _merge(s, v):
    '''
    Merges a tuple (s, V) into a single `torch.Tensor`, where the
    vector channels are flattened and appended to the scalar channels.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
    return torch.cat([s, v], -1)

class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
    :param activations: tuple of functions (scalar_act, vector_act)
    :param tuple_io: whether to keep accepting tuple inputs and outputs when vi
                     or vo = 0
    '''
    def __init__(self, in_dims, out_dims, h_dim=None, vector_gate=False,
                 activations=(F.relu, torch.sigmoid), tuple_io=True,
                 eps=1e-8):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.tuple_io = tuple_io
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if vector_gate:
                    self.wg = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)

        self.vector_gate = vector_gate
        self.scalar_act, self.vector_act = activations
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if self.vi:
            s, v = x
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
158 |
+
vn = _norm_no_nan(vh, axis=-2, eps=self.eps)
|
159 |
+
s = self.ws(torch.cat([s, vn], -1))
|
160 |
+
if self.scalar_act:
|
161 |
+
s = self.scalar_act(s)
|
162 |
+
if self.vo:
|
163 |
+
v = self.wv(vh)
|
164 |
+
v = torch.transpose(v, -1, -2)
|
165 |
+
if self.vector_gate:
|
166 |
+
g = self.wg(s).unsqueeze(-1)
|
167 |
+
else:
|
168 |
+
g = _norm_no_nan(v, axis=-1, keepdims=True, eps=self.eps)
|
169 |
+
if self.vector_act:
|
170 |
+
g = self.vector_act(g)
|
171 |
+
v = v * g
|
172 |
+
else:
|
173 |
+
if self.tuple_io:
|
174 |
+
assert x[1] is None
|
175 |
+
x = x[0]
|
176 |
+
s = self.ws(x)
|
177 |
+
if self.scalar_act:
|
178 |
+
s = self.scalar_act(s)
|
179 |
+
if self.vo:
|
180 |
+
v = torch.zeros(list(s.shape)[:-1] + [self.vo, 3],
|
181 |
+
device=s.device)
|
182 |
+
|
183 |
+
if self.vo:
|
184 |
+
return (s, v)
|
185 |
+
elif self.tuple_io:
|
186 |
+
return (s, None)
|
187 |
+
else:
|
188 |
+
return s
|
189 |
+
|
190 |
+
|
191 |
+
class _VDropout(nn.Module):
|
192 |
+
'''
|
193 |
+
Vector channel dropout where the elements of each
|
194 |
+
vector channel are dropped together.
|
195 |
+
'''
|
196 |
+
def __init__(self, drop_rate):
|
197 |
+
super(_VDropout, self).__init__()
|
198 |
+
self.drop_rate = drop_rate
|
199 |
+
|
200 |
+
def forward(self, x):
|
201 |
+
'''
|
202 |
+
:param x: `torch.Tensor` corresponding to vector channels
|
203 |
+
'''
|
204 |
+
if x is None:
|
205 |
+
return None
|
206 |
+
device = x.device
|
207 |
+
if not self.training:
|
208 |
+
return x
|
209 |
+
mask = torch.bernoulli(
|
210 |
+
(1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
|
211 |
+
).unsqueeze(-1)
|
212 |
+
x = mask * x / (1 - self.drop_rate)
|
213 |
+
return x
|
214 |
+
|
215 |
+
class Dropout(nn.Module):
|
216 |
+
'''
|
217 |
+
Combined dropout for tuples (s, V).
|
218 |
+
Takes tuples (s, V) as input and as output.
|
219 |
+
'''
|
220 |
+
def __init__(self, drop_rate):
|
221 |
+
super(Dropout, self).__init__()
|
222 |
+
self.sdropout = nn.Dropout(drop_rate)
|
223 |
+
self.vdropout = _VDropout(drop_rate)
|
224 |
+
|
225 |
+
def forward(self, x):
|
226 |
+
'''
|
227 |
+
:param x: tuple (s, V) of `torch.Tensor`,
|
228 |
+
or single `torch.Tensor`
|
229 |
+
(will be assumed to be scalar channels)
|
230 |
+
'''
|
231 |
+
if type(x) is torch.Tensor:
|
232 |
+
return self.sdropout(x)
|
233 |
+
s, v = x
|
234 |
+
return self.sdropout(s), self.vdropout(v)
|
235 |
+
|
236 |
+
class LayerNorm(nn.Module):
|
237 |
+
'''
|
238 |
+
Combined LayerNorm for tuples (s, V).
|
239 |
+
Takes tuples (s, V) as input and as output.
|
240 |
+
'''
|
241 |
+
def __init__(self, dims, tuple_io=True, eps=1e-8):
|
242 |
+
super(LayerNorm, self).__init__()
|
243 |
+
self.tuple_io = tuple_io
|
244 |
+
self.s, self.v = dims
|
245 |
+
self.scalar_norm = nn.LayerNorm(self.s)
|
246 |
+
self.eps = eps
|
247 |
+
|
248 |
+
def forward(self, x):
|
249 |
+
'''
|
250 |
+
:param x: tuple (s, V) of `torch.Tensor`,
|
251 |
+
or single `torch.Tensor`
|
252 |
+
(will be assumed to be scalar channels)
|
253 |
+
'''
|
254 |
+
if not self.v:
|
255 |
+
if self.tuple_io:
|
256 |
+
return self.scalar_norm(x[0]), None
|
257 |
+
return self.scalar_norm(x)
|
258 |
+
s, v = x
|
259 |
+
vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False, eps=self.eps)
|
260 |
+
nonzero_mask = (vn > 2 * self.eps)
|
261 |
+
vn = torch.sum(vn * nonzero_mask, dim=-2, keepdim=True
|
262 |
+
) / (self.eps + torch.sum(nonzero_mask, dim=-2, keepdim=True))
|
263 |
+
vn = torch.sqrt(vn + self.eps)
|
264 |
+
v = nonzero_mask * (v / vn)
|
265 |
+
return self.scalar_norm(s), v
|
266 |
+
|
267 |
+
class GVPConv(MessagePassing):
|
268 |
+
'''
|
269 |
+
Graph convolution / message passing with Geometric Vector Perceptrons.
|
270 |
+
Takes in a graph with node and edge embeddings,
|
271 |
+
and returns new node embeddings.
|
272 |
+
|
273 |
+
This does NOT do residual updates and pointwise feedforward layers
|
274 |
+
---see `GVPConvLayer`.
|
275 |
+
|
276 |
+
:param in_dims: input node embedding dimensions (n_scalar, n_vector)
|
277 |
+
:param out_dims: output node embedding dimensions (n_scalar, n_vector)
|
278 |
+
:param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
|
279 |
+
:param n_layers: number of GVPs in the message function
|
280 |
+
:param module_list: preconstructed message function, overrides n_layers
|
281 |
+
:param aggr: should be "add" if some incoming edges are masked, as in
|
282 |
+
a masked autoregressive decoder architecture
|
283 |
+
'''
|
284 |
+
def __init__(self, in_dims, out_dims, edge_dims, n_layers=3,
|
285 |
+
vector_gate=False, module_list=None, aggr="mean", eps=1e-8,
|
286 |
+
activations=(F.relu, torch.sigmoid)):
|
287 |
+
super(GVPConv, self).__init__(aggr=aggr)
|
288 |
+
self.eps = eps
|
289 |
+
self.si, self.vi = in_dims
|
290 |
+
self.so, self.vo = out_dims
|
291 |
+
self.se, self.ve = edge_dims
|
292 |
+
|
293 |
+
module_list = module_list or []
|
294 |
+
if not module_list:
|
295 |
+
if n_layers == 1:
|
296 |
+
module_list.append(
|
297 |
+
GVP((2*self.si + self.se, 2*self.vi + self.ve),
|
298 |
+
(self.so, self.vo), activations=(None, None)))
|
299 |
+
else:
|
300 |
+
module_list.append(
|
301 |
+
GVP((2*self.si + self.se, 2*self.vi + self.ve), out_dims,
|
302 |
+
vector_gate=vector_gate, activations=activations)
|
303 |
+
)
|
304 |
+
for i in range(n_layers - 2):
|
305 |
+
module_list.append(GVP(out_dims, out_dims,
|
306 |
+
vector_gate=vector_gate))
|
307 |
+
module_list.append(GVP(out_dims, out_dims,
|
308 |
+
activations=(None, None)))
|
309 |
+
self.message_func = nn.Sequential(*module_list)
|
310 |
+
|
311 |
+
def forward(self, x, edge_index, edge_attr):
|
312 |
+
'''
|
313 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
314 |
+
:param edge_index: array of shape [2, n_edges]
|
315 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
316 |
+
'''
|
317 |
+
x_s, x_v = x
|
318 |
+
message = self.propagate(edge_index,
|
319 |
+
s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
|
320 |
+
edge_attr=edge_attr)
|
321 |
+
return _split(message, self.vo)
|
322 |
+
|
323 |
+
def message(self, s_i, v_i, s_j, v_j, edge_attr):
|
324 |
+
v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
|
325 |
+
v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
|
326 |
+
message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
|
327 |
+
message = self.message_func(message)
|
328 |
+
return _merge(*message)
|
329 |
+
|
330 |
+
|
331 |
+
class GVPConvLayer(nn.Module):
|
332 |
+
'''
|
333 |
+
Full graph convolution / message passing layer with
|
334 |
+
Geometric Vector Perceptrons. Residually updates node embeddings with
|
335 |
+
aggregated incoming messages, applies a pointwise feedforward
|
336 |
+
network to node embeddings, and returns updated node embeddings.
|
337 |
+
|
338 |
+
To only compute the aggregated messages, see `GVPConv`.
|
339 |
+
|
340 |
+
:param node_dims: node embedding dimensions (n_scalar, n_vector)
|
341 |
+
:param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
|
342 |
+
:param n_message: number of GVPs to use in message function
|
343 |
+
:param n_feedforward: number of GVPs to use in feedforward function
|
344 |
+
:param drop_rate: drop probability in all dropout layers
|
345 |
+
:param autoregressive: if `True`, this `GVPConvLayer` will be used
|
346 |
+
with a different set of input node embeddings for messages
|
347 |
+
where src >= dst
|
348 |
+
'''
|
349 |
+
def __init__(self, node_dims, edge_dims, vector_gate=False,
|
350 |
+
n_message=3, n_feedforward=2, drop_rate=.1,
|
351 |
+
autoregressive=False, attention_heads=0,
|
352 |
+
conv_activations=(F.relu, torch.sigmoid),
|
353 |
+
n_edge_gvps=0, layernorm=True, eps=1e-8):
|
354 |
+
|
355 |
+
super(GVPConvLayer, self).__init__()
|
356 |
+
if attention_heads == 0:
|
357 |
+
self.conv = GVPConv(
|
358 |
+
node_dims, node_dims, edge_dims, n_layers=n_message,
|
359 |
+
vector_gate=vector_gate,
|
360 |
+
aggr="add" if autoregressive else "mean",
|
361 |
+
activations=conv_activations,
|
362 |
+
eps=eps,
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
raise NotImplementedError
|
366 |
+
if layernorm:
|
367 |
+
self.norm = nn.ModuleList([LayerNorm(node_dims, eps=eps) for _ in range(2)])
|
368 |
+
else:
|
369 |
+
self.norm = nn.ModuleList([nn.Identity() for _ in range(2)])
|
370 |
+
self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
|
371 |
+
|
372 |
+
ff_func = []
|
373 |
+
if n_feedforward == 1:
|
374 |
+
ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
|
375 |
+
else:
|
376 |
+
hid_dims = 4*node_dims[0], 2*node_dims[1]
|
377 |
+
ff_func.append(GVP(node_dims, hid_dims, vector_gate=vector_gate))
|
378 |
+
for i in range(n_feedforward-2):
|
379 |
+
ff_func.append(GVP(hid_dims, hid_dims, vector_gate=vector_gate))
|
380 |
+
ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
|
381 |
+
self.ff_func = nn.Sequential(*ff_func)
|
382 |
+
|
383 |
+
self.edge_message_func = None
|
384 |
+
if n_edge_gvps > 0:
|
385 |
+
si, vi = node_dims
|
386 |
+
se, ve = edge_dims
|
387 |
+
module_list = [
|
388 |
+
GVP((2*si + se, 2*vi + ve), edge_dims, vector_gate=vector_gate)
|
389 |
+
]
|
390 |
+
for i in range(n_edge_gvps - 2):
|
391 |
+
module_list.append(GVP(edge_dims, edge_dims,
|
392 |
+
vector_gate=vector_gate))
|
393 |
+
if n_edge_gvps > 1:
|
394 |
+
module_list.append(GVP(edge_dims, edge_dims,
|
395 |
+
activations=(None, None)))
|
396 |
+
self.edge_message_func = nn.Sequential(*module_list)
|
397 |
+
if layernorm:
|
398 |
+
self.edge_norm = LayerNorm(edge_dims, eps=eps)
|
399 |
+
else:
|
400 |
+
self.edge_norm = nn.Identity()
|
401 |
+
self.edge_dropout = Dropout(drop_rate)
|
402 |
+
|
403 |
+
def forward(self, x, edge_index, edge_attr,
|
404 |
+
autoregressive_x=None, node_mask=None):
|
405 |
+
'''
|
406 |
+
:param x: tuple (s, V) of `torch.Tensor`
|
407 |
+
:param edge_index: array of shape [2, n_edges]
|
408 |
+
:param edge_attr: tuple (s, V) of `torch.Tensor`
|
409 |
+
:param autoregressive_x: tuple (s, V) of `torch.Tensor`.
|
410 |
+
If not `None`, will be used as src node embeddings
|
411 |
+
for forming messages where src >= dst. The current node
|
412 |
+
embeddings `x` will still be the base of the update and the
|
413 |
+
pointwise feedforward.
|
414 |
+
:param node_mask: array of type `bool` to index into the first
|
415 |
+
dim of node embeddings (s, V). If not `None`, only
|
416 |
+
these nodes will be updated.
|
417 |
+
'''
|
418 |
+
if self.edge_message_func:
|
419 |
+
src, dst = edge_index
|
420 |
+
if autoregressive_x is None:
|
421 |
+
x_src = x[0][src], x[1][src]
|
422 |
+
else:
|
423 |
+
mask = (src < dst).unsqueeze(-1)
|
424 |
+
x_src = (
|
425 |
+
torch.where(mask, x[0][src], autoregressive_x[0][src]),
|
426 |
+
torch.where(mask.unsqueeze(-1), x[1][src],
|
427 |
+
autoregressive_x[1][src])
|
428 |
+
)
|
429 |
+
x_dst = x[0][dst], x[1][dst]
|
430 |
+
x_edge = (
|
431 |
+
torch.cat([x_src[0], edge_attr[0], x_dst[0]], dim=-1),
|
432 |
+
torch.cat([x_src[1], edge_attr[1], x_dst[1]], dim=-2)
|
433 |
+
)
|
434 |
+
edge_attr_dh = self.edge_message_func(x_edge)
|
435 |
+
edge_attr = self.edge_norm(tuple_sum(edge_attr,
|
436 |
+
self.edge_dropout(edge_attr_dh)))
|
437 |
+
|
438 |
+
if autoregressive_x is not None:
|
439 |
+
# Guarding this import here to remove the dependency on torch_scatter, since this isn't used
|
440 |
+
# in ESM-IF1
|
441 |
+
from torch_scatter import scatter_add
|
442 |
+
src, dst = edge_index
|
443 |
+
mask = src < dst
|
444 |
+
edge_index_forward = edge_index[:, mask]
|
445 |
+
edge_index_backward = edge_index[:, ~mask]
|
446 |
+
edge_attr_forward = tuple_index(edge_attr, mask)
|
447 |
+
edge_attr_backward = tuple_index(edge_attr, ~mask)
|
448 |
+
|
449 |
+
dh = tuple_sum(
|
450 |
+
self.conv(x, edge_index_forward, edge_attr_forward),
|
451 |
+
self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
|
452 |
+
)
|
453 |
+
|
454 |
+
count = scatter_add(torch.ones_like(dst), dst,
|
455 |
+
dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
|
456 |
+
|
457 |
+
dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
|
458 |
+
|
459 |
+
else:
|
460 |
+
dh = self.conv(x, edge_index, edge_attr)
|
461 |
+
|
462 |
+
if node_mask is not None:
|
463 |
+
x_ = x
|
464 |
+
x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
|
465 |
+
|
466 |
+
x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
|
467 |
+
|
468 |
+
dh = self.ff_func(x)
|
469 |
+
x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
|
470 |
+
|
471 |
+
if node_mask is not None:
|
472 |
+
x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
|
473 |
+
x = x_
|
474 |
+
|
475 |
+
return x, edge_attr
|
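
The GVP above consumes and produces (scalar, vector) tuples, so a quick sanity check is to push a random tuple through one GVP followed by the tuple LayerNorm. The sketch below is illustrative only: the channel sizes are arbitrary and the import path assumes this module is installed as esm.inverse_folding.gvp_modules.

from esm.inverse_folding.gvp_modules import GVP, LayerNorm, randn

# 8 nodes with 16 scalar and 4 vector channels, drawn from a normal distribution
s, v = randn(8, (16, 4))
# project to 32 scalar / 8 vector channels with the default activations
gvp = GVP((16, 4), (32, 8))
out_s, out_v = LayerNorm((32, 8))(gvp((s, v)))
assert out_s.shape == (8, 32) and out_v.shape == (8, 8, 3)
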
esm/inverse_folding/gvp_transformer.py
ADDED
@@ -0,0 +1,140 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch import Tensor
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from scipy.spatial import transform
|
13 |
+
|
14 |
+
from esm.data import Alphabet
|
15 |
+
|
16 |
+
from .features import DihedralFeatures
|
17 |
+
from .gvp_encoder import GVPEncoder
|
18 |
+
from .gvp_utils import unflatten_graph
|
19 |
+
from .gvp_transformer_encoder import GVPTransformerEncoder
|
20 |
+
from .transformer_decoder import TransformerDecoder
|
21 |
+
from .util import rotate, CoordBatchConverter
|
22 |
+
|
23 |
+
|
24 |
+
class GVPTransformerModel(nn.Module):
|
25 |
+
"""
|
26 |
+
GVP-Transformer inverse folding model.
|
27 |
+
|
28 |
+
Architecture: Geometric GVP-GNN as initial layers, followed by
|
29 |
+
sequence-to-sequence Transformer encoder and decoder.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, args, alphabet):
|
33 |
+
super().__init__()
|
34 |
+
encoder_embed_tokens = self.build_embedding(
|
35 |
+
args, alphabet, args.encoder_embed_dim,
|
36 |
+
)
|
37 |
+
decoder_embed_tokens = self.build_embedding(
|
38 |
+
args, alphabet, args.decoder_embed_dim,
|
39 |
+
)
|
40 |
+
encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
|
41 |
+
decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
|
42 |
+
self.args = args
|
43 |
+
self.encoder = encoder
|
44 |
+
self.decoder = decoder
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def build_encoder(cls, args, src_dict, embed_tokens):
|
48 |
+
encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
|
49 |
+
return encoder
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def build_decoder(cls, args, tgt_dict, embed_tokens):
|
53 |
+
decoder = TransformerDecoder(
|
54 |
+
args,
|
55 |
+
tgt_dict,
|
56 |
+
embed_tokens,
|
57 |
+
)
|
58 |
+
return decoder
|
59 |
+
|
60 |
+
@classmethod
|
61 |
+
def build_embedding(cls, args, dictionary, embed_dim):
|
62 |
+
num_embeddings = len(dictionary)
|
63 |
+
padding_idx = dictionary.padding_idx
|
64 |
+
emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
|
65 |
+
nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
|
66 |
+
nn.init.constant_(emb.weight[padding_idx], 0)
|
67 |
+
return emb
|
68 |
+
|
69 |
+
def forward(
|
70 |
+
self,
|
71 |
+
coords,
|
72 |
+
padding_mask,
|
73 |
+
confidence,
|
74 |
+
prev_output_tokens,
|
75 |
+
return_all_hiddens: bool = False,
|
76 |
+
features_only: bool = False,
|
77 |
+
):
|
78 |
+
encoder_out = self.encoder(coords, padding_mask, confidence,
|
79 |
+
return_all_hiddens=return_all_hiddens)
|
80 |
+
logits, extra = self.decoder(
|
81 |
+
prev_output_tokens,
|
82 |
+
encoder_out=encoder_out,
|
83 |
+
features_only=features_only,
|
84 |
+
return_all_hiddens=return_all_hiddens,
|
85 |
+
)
|
86 |
+
return logits, extra
|
87 |
+
|
88 |
+
def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
|
89 |
+
"""
|
90 |
+
Samples sequences based on multinomial sampling (no beam search).
|
91 |
+
|
92 |
+
Args:
|
93 |
+
coords: L x 3 x 3 list representing one backbone
|
94 |
+
partial_seq: Optional, partial sequence with mask tokens if part of
|
95 |
+
the sequence is known
|
96 |
+
temperature: sampling temperature, use low temperature for higher
|
97 |
+
sequence recovery and high temperature for higher diversity
|
98 |
+
confidence: optional length L list of confidence scores for coordinates
|
99 |
+
"""
|
100 |
+
L = len(coords)
|
101 |
+
# Convert to batch format
|
102 |
+
batch_converter = CoordBatchConverter(self.decoder.dictionary)
|
103 |
+
batch_coords, confidence, _, _, padding_mask = (
|
104 |
+
batch_converter([(coords, confidence, None)], device=device)
|
105 |
+
)
|
106 |
+
|
107 |
+
# Start with prepend token
|
108 |
+
mask_idx = self.decoder.dictionary.get_idx('<mask>')
|
109 |
+
sampled_tokens = torch.full((1, 1+L), mask_idx, dtype=int)
|
110 |
+
sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
|
111 |
+
if partial_seq is not None:
|
112 |
+
for i, c in enumerate(partial_seq):
|
113 |
+
sampled_tokens[0, i+1] = self.decoder.dictionary.get_idx(c)
|
114 |
+
|
115 |
+
# Save incremental states for faster sampling
|
116 |
+
incremental_state = dict()
|
117 |
+
|
118 |
+
# Run encoder only once
|
119 |
+
encoder_out = self.encoder(batch_coords, padding_mask, confidence)
|
120 |
+
|
121 |
+
# Make sure all tensors are on the same device if a GPU is present
|
122 |
+
if device:
|
123 |
+
sampled_tokens = sampled_tokens.to(device)
|
124 |
+
|
125 |
+
# Decode one token at a time
|
126 |
+
for i in range(1, L+1):
|
127 |
+
logits, _ = self.decoder(
|
128 |
+
sampled_tokens[:, :i],
|
129 |
+
encoder_out,
|
130 |
+
incremental_state=incremental_state,
|
131 |
+
)
|
132 |
+
logits = logits[0].transpose(0, 1)
|
133 |
+
logits /= temperature
|
134 |
+
probs = F.softmax(logits, dim=-1)
|
135 |
+
if sampled_tokens[0, i] == mask_idx:
|
136 |
+
sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
|
137 |
+
sampled_seq = sampled_tokens[0, 1:]
|
138 |
+
|
139 |
+
# Convert back to string via lookup
|
140 |
+
return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq])
|
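
A minimal sampling sketch for the model above, assuming a pretrained checkpoint loaded through esm.pretrained.esm_if1_gvp4_t16_142M_UR50() and a hypothetical structure file protein.pdb; load_coords lives in esm.inverse_folding.util, the same module that supplies the CoordBatchConverter used inside sample().

import torch
import esm
from esm.inverse_folding.util import load_coords

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()  # assumed checkpoint loader
model = model.eval()
coords, native_seq = load_coords("protein.pdb", "A")  # hypothetical file, chain A
with torch.no_grad():
    # a low temperature biases sampling toward higher native sequence recovery
    designed_seq = model.sample(coords, temperature=0.1)
print(native_seq, designed_seq, sep="\n")
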
esm/inverse_folding/gvp_transformer_encoder.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# Contents of this file were adapted from the open source fairseq repository.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import math
|
10 |
+
from typing import Dict, List, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
from esm.modules import SinusoidalPositionalEmbedding
|
17 |
+
from .features import GVPInputFeaturizer, DihedralFeatures
|
18 |
+
from .gvp_encoder import GVPEncoder
|
19 |
+
from .transformer_layer import TransformerEncoderLayer
|
20 |
+
from .util import nan_to_num, get_rotation_frames, rotate, rbf
|
21 |
+
|
22 |
+
|
23 |
+
class GVPTransformerEncoder(nn.Module):
|
24 |
+
"""
|
25 |
+
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
|
26 |
+
is a :class:`TransformerEncoderLayer`.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
args (argparse.Namespace): parsed command-line arguments
|
30 |
+
dictionary (~fairseq.data.Dictionary): encoding dictionary
|
31 |
+
embed_tokens (torch.nn.Embedding): input embedding
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, args, dictionary, embed_tokens):
|
35 |
+
super().__init__()
|
36 |
+
self.args = args
|
37 |
+
self.dictionary = dictionary
|
38 |
+
|
39 |
+
self.dropout_module = nn.Dropout(args.dropout)
|
40 |
+
|
41 |
+
embed_dim = embed_tokens.embedding_dim
|
42 |
+
self.padding_idx = embed_tokens.padding_idx
|
43 |
+
|
44 |
+
self.embed_tokens = embed_tokens
|
45 |
+
self.embed_scale = math.sqrt(embed_dim)
|
46 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
47 |
+
embed_dim,
|
48 |
+
self.padding_idx,
|
49 |
+
)
|
50 |
+
self.embed_gvp_input_features = nn.Linear(15, embed_dim)
|
51 |
+
self.embed_confidence = nn.Linear(16, embed_dim)
|
52 |
+
self.embed_dihedrals = DihedralFeatures(embed_dim)
|
53 |
+
|
54 |
+
gvp_args = argparse.Namespace()
|
55 |
+
for k, v in vars(args).items():
|
56 |
+
if k.startswith("gvp_"):
|
57 |
+
setattr(gvp_args, k[4:], v)
|
58 |
+
self.gvp_encoder = GVPEncoder(gvp_args)
|
59 |
+
gvp_out_dim = gvp_args.node_hidden_dim_scalar + (3 *
|
60 |
+
gvp_args.node_hidden_dim_vector)
|
61 |
+
self.embed_gvp_output = nn.Linear(gvp_out_dim, embed_dim)
|
62 |
+
|
63 |
+
self.layers = nn.ModuleList([])
|
64 |
+
self.layers.extend(
|
65 |
+
[self.build_encoder_layer(args) for i in range(args.encoder_layers)]
|
66 |
+
)
|
67 |
+
self.num_layers = len(self.layers)
|
68 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
69 |
+
|
70 |
+
def build_encoder_layer(self, args):
|
71 |
+
return TransformerEncoderLayer(args)
|
72 |
+
|
73 |
+
def forward_embedding(self, coords, padding_mask, confidence):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
coords: N, CA, C backbone coordinates of shape batch_size x length x 3 (atoms) x 3
|
77 |
+
padding_mask: boolean Tensor (True for padding) of shape batch_size x length
|
78 |
+
confidence: confidence scores between 0 and 1 of shape batch_size x length
|
79 |
+
"""
|
80 |
+
components = dict()
|
81 |
+
coord_mask = torch.all(torch.all(torch.isfinite(coords), dim=-1), dim=-1)
|
82 |
+
coords = nan_to_num(coords)
|
83 |
+
mask_tokens = (
|
84 |
+
padding_mask * self.dictionary.padding_idx +
|
85 |
+
~padding_mask * self.dictionary.get_idx("<mask>")
|
86 |
+
)
|
87 |
+
components["tokens"] = self.embed_tokens(mask_tokens) * self.embed_scale
|
88 |
+
components["dihedrals"] = self.embed_dihedrals(coords)
|
89 |
+
|
90 |
+
# GVP encoder
|
91 |
+
gvp_out_scalars, gvp_out_vectors = self.gvp_encoder(coords,
|
92 |
+
coord_mask, padding_mask, confidence)
|
93 |
+
R = get_rotation_frames(coords)
|
94 |
+
# Rotate to local rotation frame for rotation-invariance
|
95 |
+
gvp_out_features = torch.cat([
|
96 |
+
gvp_out_scalars,
|
97 |
+
rotate(gvp_out_vectors, R.transpose(-2, -1)).flatten(-2, -1),
|
98 |
+
], dim=-1)
|
99 |
+
components["gvp_out"] = self.embed_gvp_output(gvp_out_features)
|
100 |
+
|
101 |
+
components["confidence"] = self.embed_confidence(
|
102 |
+
rbf(confidence, 0., 1.))
|
103 |
+
|
104 |
+
# In addition to GVP encoder outputs, also directly embed GVP input node
|
105 |
+
# features to the Transformer
|
106 |
+
scalar_features, vector_features = GVPInputFeaturizer.get_node_features(
|
107 |
+
coords, coord_mask, with_coord_mask=False)
|
108 |
+
features = torch.cat([
|
109 |
+
scalar_features,
|
110 |
+
rotate(vector_features, R.transpose(-2, -1)).flatten(-2, -1),
|
111 |
+
], dim=-1)
|
112 |
+
components["gvp_input_features"] = self.embed_gvp_input_features(features)
|
113 |
+
|
114 |
+
embed = sum(components.values())
|
115 |
+
# for k, v in components.items():
|
116 |
+
# print(k, torch.mean(v, dim=(0,1)), torch.std(v, dim=(0,1)))
|
117 |
+
|
118 |
+
x = embed
|
119 |
+
x = x + self.embed_positions(mask_tokens)
|
120 |
+
x = self.dropout_module(x)
|
121 |
+
return x, components
|
122 |
+
|
123 |
+
def forward(
|
124 |
+
self,
|
125 |
+
coords,
|
126 |
+
encoder_padding_mask,
|
127 |
+
confidence,
|
128 |
+
return_all_hiddens: bool = False,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
Args:
|
132 |
+
coords (Tensor): backbone coordinates
|
133 |
+
shape batch_size x num_residues x num_atoms (3 for N, CA, C) x 3
|
134 |
+
encoder_padding_mask (ByteTensor): the positions of
|
135 |
+
padding elements of shape `(batch_size x num_residues)`
|
136 |
+
confidence (Tensor): the confidence score of shape (batch_size x
|
137 |
+
num_residues). The value is between 0. and 1. for each residue
|
138 |
+
coordinate, or -1. if no coordinate is given
|
139 |
+
return_all_hiddens (bool, optional): also return all of the
|
140 |
+
intermediate hidden states (default: False).
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
dict:
|
144 |
+
- **encoder_out** (Tensor): the last encoder layer's output of
|
145 |
+
shape `(num_residues, batch_size, embed_dim)`
|
146 |
+
- **encoder_padding_mask** (ByteTensor): the positions of
|
147 |
+
padding elements of shape `(batch_size, num_residues)`
|
148 |
+
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
|
149 |
+
of shape `(batch_size, num_residues, embed_dim)`
|
150 |
+
- **encoder_states** (List[Tensor]): all intermediate
|
151 |
+
hidden states of shape `(num_residues, batch_size, embed_dim)`.
|
152 |
+
Only populated if *return_all_hiddens* is True.
|
153 |
+
"""
|
154 |
+
x, encoder_embedding = self.forward_embedding(coords,
|
155 |
+
encoder_padding_mask, confidence)
|
156 |
+
# account for padding while computing the representation
|
157 |
+
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
|
158 |
+
|
159 |
+
# B x T x C -> T x B x C
|
160 |
+
x = x.transpose(0, 1)
|
161 |
+
|
162 |
+
encoder_states = []
|
163 |
+
|
164 |
+
if return_all_hiddens:
|
165 |
+
encoder_states.append(x)
|
166 |
+
|
167 |
+
# encoder layers
|
168 |
+
for layer in self.layers:
|
169 |
+
x = layer(
|
170 |
+
x, encoder_padding_mask=encoder_padding_mask
|
171 |
+
)
|
172 |
+
if return_all_hiddens:
|
173 |
+
assert encoder_states is not None
|
174 |
+
encoder_states.append(x)
|
175 |
+
|
176 |
+
if self.layer_norm is not None:
|
177 |
+
x = self.layer_norm(x)
|
178 |
+
|
179 |
+
return {
|
180 |
+
"encoder_out": [x], # T x B x C
|
181 |
+
"encoder_padding_mask": [encoder_padding_mask], # B x T
|
182 |
+
"encoder_embedding": [encoder_embedding], # dictionary
|
183 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
184 |
+
}
|
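
The encoder works internally in a time-first (T, B, C) layout, which is why forward() transposes after embedding and why the returned encoder_out entries are documented as T x B x C. A small shape sketch of that convention on dummy tensors with hypothetical sizes:

import torch

B, T, C = 2, 7, 16                                   # batch, residues, embed_dim
x = torch.randn(B, T, C)                             # batch-first embeddings
padding_mask = torch.zeros(B, T, dtype=torch.bool)   # True would mark padding
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))  # zero out padded positions
x = x.transpose(0, 1)                                # B x T x C -> T x B x C
assert x.shape == (T, B, C)
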
esm/inverse_folding/gvp_utils.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def flatten_graph(node_embeddings, edge_embeddings, edge_index):
|
10 |
+
"""
|
11 |
+
Flattens the batched graph into a single graph of batch size one (with disconnected subgraphs for
|
12 |
+
each example), to be compatible with the pytorch-geometric package.
|
13 |
+
Args:
|
14 |
+
node_embeddings: node embeddings in tuple form (scalar, vector)
|
15 |
+
- scalar: shape batch size x nodes x node_embed_dim
|
16 |
+
- vector: shape batch size x nodes x node_embed_dim x 3
|
17 |
+
edge_embeddings: edge embeddings in tuple form (scalar, vector)
|
18 |
+
- scalar: shape batch size x edges x edge_embed_dim
|
19 |
+
- vector: shape batch size x edges x edge_embed_dim x 3
|
20 |
+
edge_index: shape batch_size x 2 (source node and target node) x edges
|
21 |
+
Returns:
|
22 |
+
node_embeddings: node embeddings in tuple form (scalar, vector)
|
23 |
+
- scalar: shape total_nodes x node_embed_dim
|
24 |
+
- vector: shape total_nodes x node_embed_dim x 3
|
25 |
+
edge_embeddings: edge embeddings in tuple form (scalar, vector)
|
26 |
+
- scalar: shape total_edges x edge_embed_dim
|
27 |
+
- vector: shape total_edges x edge_embed_dim x 3
|
28 |
+
edge_index: shape 2 x total_edges
|
29 |
+
"""
|
30 |
+
x_s, x_v = node_embeddings
|
31 |
+
e_s, e_v = edge_embeddings
|
32 |
+
batch_size, N = x_s.shape[0], x_s.shape[1]
|
33 |
+
node_embeddings = (torch.flatten(x_s, 0, 1), torch.flatten(x_v, 0, 1))
|
34 |
+
edge_embeddings = (torch.flatten(e_s, 0, 1), torch.flatten(e_v, 0, 1))
|
35 |
+
|
36 |
+
edge_mask = torch.any(edge_index != -1, dim=1)
|
37 |
+
# Re-number the nodes by adding batch_idx * N to each batch
|
38 |
+
edge_index = edge_index + (torch.arange(batch_size, device=edge_index.device) *
|
39 |
+
N).unsqueeze(-1).unsqueeze(-1)
|
40 |
+
edge_index = edge_index.permute(1, 0, 2).flatten(1, 2)
|
41 |
+
edge_mask = edge_mask.flatten()
|
42 |
+
edge_index = edge_index[:, edge_mask]
|
43 |
+
edge_embeddings = (
|
44 |
+
edge_embeddings[0][edge_mask, :],
|
45 |
+
edge_embeddings[1][edge_mask, :]
|
46 |
+
)
|
47 |
+
return node_embeddings, edge_embeddings, edge_index
|
48 |
+
|
49 |
+
|
50 |
+
def unflatten_graph(node_embeddings, batch_size):
|
51 |
+
"""
|
52 |
+
Unflattens node embeddings.
|
53 |
+
Args:
|
54 |
+
node_embeddings: node embeddings in tuple form (scalar, vector)
|
55 |
+
- scalar: shape total_nodes x node_embed_dim
|
56 |
+
- vector: shape total_nodes x node_embed_dim x 3
|
57 |
+
batch_size: int
|
58 |
+
Returns:
|
59 |
+
node_embeddings: node embeddings in tuple form (scalar, vector)
|
60 |
+
- scalar: shape batch size x nodes x node_embed_dim
|
61 |
+
- vector: shape batch size x nodes x node_embed_dim x 3
|
62 |
+
"""
|
63 |
+
x_s, x_v = node_embeddings
|
64 |
+
x_s = x_s.reshape(batch_size, -1, x_s.shape[1])
|
65 |
+
x_v = x_v.reshape(batch_size, -1, x_v.shape[1], x_v.shape[2])
|
66 |
+
return (x_s, x_v)
|
67 |
+
|
68 |
+
|
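
A round-trip sketch for the two helpers above, with hypothetical batch, node, and edge counts; it only exercises the shape bookkeeping that flatten_graph and unflatten_graph perform.

import torch
from esm.inverse_folding.gvp_utils import flatten_graph, unflatten_graph

B, N, E = 2, 5, 8
nodes = (torch.randn(B, N, 16), torch.randn(B, N, 4, 3))   # (scalar, vector)
edges = (torch.randn(B, E, 8), torch.randn(B, E, 1, 3))
edge_index = torch.randint(0, N, (B, 2, E))                # -1 entries would mark padded edges
flat_nodes, flat_edges, flat_index = flatten_graph(nodes, edges, edge_index)
assert flat_nodes[0].shape == (B * N, 16) and flat_index.shape == (2, B * E)
assert unflatten_graph(flat_nodes, B)[0].shape == (B, N, 16)
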
esm/inverse_folding/multichain_util.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import biotite.structure
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from typing import Sequence, Tuple, List
|
10 |
+
|
11 |
+
from esm.inverse_folding.util import (
|
12 |
+
load_structure,
|
13 |
+
extract_coords_from_structure,
|
14 |
+
load_coords,
|
15 |
+
get_sequence_loss,
|
16 |
+
get_encoder_output,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
def extract_coords_from_complex(structure: biotite.structure.AtomArray):
|
21 |
+
"""
|
22 |
+
Args:
|
23 |
+
structure: biotite AtomArray
|
24 |
+
Returns:
|
25 |
+
Tuple (coords_list, seq_list)
|
26 |
+
- coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
|
27 |
+
coordinates representing the backbone of each chain
|
28 |
+
- seqs: Dictionary mapping chain ids to native sequences of each chain
|
29 |
+
"""
|
30 |
+
coords = {}
|
31 |
+
seqs = {}
|
32 |
+
all_chains = biotite.structure.get_chains(structure)
|
33 |
+
for chain_id in all_chains:
|
34 |
+
chain = structure[structure.chain_id == chain_id]
|
35 |
+
coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain)
|
36 |
+
return coords, seqs
|
37 |
+
|
38 |
+
|
39 |
+
def load_complex_coords(fpath, chains):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
fpath: filepath to either pdb or cif file
|
43 |
+
chains: the chain ids (the order matters for autoregressive model)
|
44 |
+
Returns:
|
45 |
+
Tuple (coords_list, seq_list)
|
46 |
+
- coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
|
47 |
+
coordinates representing the backbone of each chain
|
48 |
+
- seqs: Dictionary mapping chain ids to native sequences of each chain
|
49 |
+
"""
|
50 |
+
structure = load_structure(fpath, chains)
|
51 |
+
return extract_coords_from_complex(structure)
|
52 |
+
|
53 |
+
|
54 |
+
def _concatenate_coords(coords, target_chain_id, padding_length=10):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
|
58 |
+
coordinates representing the backbone of each chain
|
59 |
+
target_chain_id: The chain id to sample sequences for
|
60 |
+
padding_length: Length of padding between concatenated chains
|
61 |
+
Returns:
|
62 |
+
coords_concatenated: an L x 3 x 3 array for N, CA, C
|
63 |
+
coordinates, a concatenation of the chains with the target
|
64 |
+
chain first and NaN padding coordinates of length
|
65 |
+
`padding_length` inserted between consecutive chains
|
66 |
+
(no sequence is returned by this helper)
|
67 |
+
"""
|
68 |
+
pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
|
69 |
+
# For best performance, put the target chain first in concatenation.
|
70 |
+
coords_list = [coords[target_chain_id]]
|
71 |
+
for chain_id in coords:
|
72 |
+
if chain_id == target_chain_id:
|
73 |
+
continue
|
74 |
+
coords_list.append(pad_coords)
|
75 |
+
coords_list.append(coords[chain_id])
|
76 |
+
coords_concatenated = np.concatenate(coords_list, axis=0)
|
77 |
+
return coords_concatenated
|
78 |
+
|
79 |
+
|
80 |
+
def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
|
81 |
+
padding_length=10):
|
82 |
+
"""
|
83 |
+
Samples sequence for one chain in a complex.
|
84 |
+
Args:
|
85 |
+
model: An instance of the GVPTransformer model
|
86 |
+
coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
|
87 |
+
coordinates representing the backbone of each chain
|
88 |
+
target_chain_id: The chain id to sample sequences for
|
89 |
+
padding_length: padding length in between chains
|
90 |
+
Returns:
|
91 |
+
Sampled sequence for the target chain
|
92 |
+
"""
|
93 |
+
target_chain_len = coords[target_chain_id].shape[0]
|
94 |
+
all_coords = _concatenate_coords(coords, target_chain_id)
|
95 |
+
device = next(model.parameters()).device
|
96 |
+
|
97 |
+
# Supply padding tokens for other chains to avoid unused sampling for speed
|
98 |
+
padding_pattern = ['<pad>'] * all_coords.shape[0]
|
99 |
+
for i in range(target_chain_len):
|
100 |
+
padding_pattern[i] = '<mask>'
|
101 |
+
sampled = model.sample(all_coords, partial_seq=padding_pattern,
|
102 |
+
temperature=temperature, device=device)
|
103 |
+
sampled = sampled[:target_chain_len]
|
104 |
+
return sampled
|
105 |
+
|
106 |
+
|
107 |
+
def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
|
108 |
+
target_seq, padding_length=10):
|
109 |
+
"""
|
110 |
+
Scores sequence for one chain in a complex.
|
111 |
+
Args:
|
112 |
+
model: An instance of the GVPTransformer model
|
113 |
+
alphabet: Alphabet for the model
|
114 |
+
coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
|
115 |
+
coordinates representing the backbone of each chain
|
116 |
+
target_chain_id: The chain id to sample sequences for
|
117 |
+
target_seq: Target sequence for the target chain for scoring.
|
118 |
+
padding_length: padding length in between chains
|
119 |
+
Returns:
|
120 |
+
Tuple (ll_fullseq, ll_withcoord)
|
121 |
+
- ll_fullseq: Average log-likelihood over the full target chain
|
122 |
+
- ll_withcoord: Average log-likelihood in target chain excluding those
|
123 |
+
residues without coordinates
|
124 |
+
"""
|
125 |
+
all_coords = _concatenate_coords(coords, target_chain_id)
|
126 |
+
|
127 |
+
loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
|
128 |
+
target_seq)
|
129 |
+
ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(
|
130 |
+
~target_padding_mask)
|
131 |
+
|
132 |
+
# Also calculate average when excluding masked portions
|
133 |
+
coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
|
134 |
+
ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
|
135 |
+
return ll_fullseq, ll_withcoord
|
136 |
+
|
137 |
+
|
138 |
+
def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
|
139 |
+
"""
|
140 |
+
Args:
|
141 |
+
model: An instance of the GVPTransformer model
|
142 |
+
alphabet: Alphabet for the model
|
143 |
+
coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
|
144 |
+
coordinates representing the backbone of each chain
|
145 |
+
target_chain_id: The chain id to sample sequences for
|
146 |
+
Returns:
|
147 |
+
Encoder output for the target chain only
|
148 |
+
"""
|
149 |
+
all_coords = _concatenate_coords(coords, target_chain_id)
|
150 |
+
all_rep = get_encoder_output(model, alphabet, all_coords)
|
151 |
+
target_chain_len = coords[target_chain_id].shape[0]
|
152 |
+
return all_rep[:target_chain_len]
|
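
A usage sketch tying the helpers above together: redesigning one chain of a two-chain complex. The checkpoint loader and the file name complex.pdb are assumptions; the chain ids follow whatever the structure file actually contains.

import esm
from esm.inverse_folding.multichain_util import (
    load_complex_coords,
    sample_sequence_in_complex,
    score_sequence_in_complex,
)

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()  # assumed checkpoint loader
model = model.eval()
coords, native_seqs = load_complex_coords("complex.pdb", chains=["A", "B"])
designed_A = sample_sequence_in_complex(model, coords, target_chain_id="A",
                                        temperature=0.1)
ll_fullseq, ll_withcoord = score_sequence_in_complex(model, alphabet, coords,
                                                     "A", designed_A)
print(designed_A, ll_fullseq)
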
esm/inverse_folding/transformer_decoder.py
ADDED
@@ -0,0 +1,228 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# Contents of this file were adapted from the open source fairseq repository.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import math
|
9 |
+
from typing import Any, Dict, List, Optional
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch import Tensor
|
14 |
+
|
15 |
+
from esm.modules import SinusoidalPositionalEmbedding
|
16 |
+
from .transformer_layer import TransformerDecoderLayer
|
17 |
+
|
18 |
+
|
19 |
+
def fill_with_neg_inf(t):
|
20 |
+
"""FP16-compatible function that fills a tensor with -inf."""
|
21 |
+
return t.float().fill_(float("-inf")).type_as(t)
|
22 |
+
|
23 |
+
|
24 |
+
class TransformerDecoder(nn.Module):
|
25 |
+
"""
|
26 |
+
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
27 |
+
is a :class:`TransformerDecoderLayer`.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
args (argparse.Namespace): parsed command-line arguments
|
31 |
+
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
32 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
33 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
34 |
+
(default: False).
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
args,
|
40 |
+
dictionary,
|
41 |
+
embed_tokens,
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
self.args = args
|
45 |
+
self.dictionary = dictionary
|
46 |
+
self._future_mask = torch.empty(0)
|
47 |
+
|
48 |
+
self.dropout_module = nn.Dropout(args.dropout)
|
49 |
+
|
50 |
+
input_embed_dim = embed_tokens.embedding_dim
|
51 |
+
embed_dim = args.decoder_embed_dim
|
52 |
+
self.embed_dim = embed_dim
|
53 |
+
|
54 |
+
self.padding_idx = embed_tokens.padding_idx
|
55 |
+
|
56 |
+
self.embed_tokens = embed_tokens
|
57 |
+
self.embed_scale = math.sqrt(embed_dim)
|
58 |
+
|
59 |
+
self.project_in_dim = (
|
60 |
+
nn.Linear(input_embed_dim, embed_dim, bias=False)
|
61 |
+
if embed_dim != input_embed_dim
|
62 |
+
else None
|
63 |
+
)
|
64 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
65 |
+
embed_dim,
|
66 |
+
self.padding_idx,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.layers = nn.ModuleList([])
|
70 |
+
self.layers.extend(
|
71 |
+
[
|
72 |
+
self.build_decoder_layer(args)
|
73 |
+
for _ in range(args.decoder_layers)
|
74 |
+
]
|
75 |
+
)
|
76 |
+
self.num_layers = len(self.layers)
|
77 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
78 |
+
|
79 |
+
self.build_output_projection(args, dictionary)
|
80 |
+
|
81 |
+
def build_output_projection(self, args, dictionary):
|
82 |
+
self.output_projection = nn.Linear(
|
83 |
+
args.decoder_embed_dim, len(dictionary), bias=False
|
84 |
+
)
|
85 |
+
nn.init.normal_(
|
86 |
+
self.output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
87 |
+
)
|
88 |
+
|
89 |
+
def build_decoder_layer(self, args):
|
90 |
+
return TransformerDecoderLayer(args)
|
91 |
+
|
92 |
+
def forward(
|
93 |
+
self,
|
94 |
+
prev_output_tokens,
|
95 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
|
96 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
97 |
+
features_only: bool = False,
|
98 |
+
return_all_hiddens: bool = False,
|
99 |
+
):
|
100 |
+
"""
|
101 |
+
Args:
|
102 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
103 |
+
`(batch, tgt_len)`, for teacher forcing
|
104 |
+
encoder_out (optional): output from the encoder, used for
|
105 |
+
encoder-side attention, should be of size T x B x C
|
106 |
+
incremental_state (dict): dictionary used for storing state during
|
107 |
+
:ref:`Incremental decoding`
|
108 |
+
features_only (bool, optional): only return features without
|
109 |
+
applying output layer (default: False).
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
tuple:
|
113 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
114 |
+
- a dictionary with any model-specific outputs
|
115 |
+
"""
|
116 |
+
|
117 |
+
x, extra = self.extract_features(
|
118 |
+
prev_output_tokens,
|
119 |
+
encoder_out=encoder_out,
|
120 |
+
incremental_state=incremental_state,
|
121 |
+
)
|
122 |
+
|
123 |
+
if not features_only:
|
124 |
+
x = self.output_layer(x)
|
125 |
+
x = x.transpose(1, 2) # B x T x C -> B x C x T
|
126 |
+
return x, extra
|
127 |
+
|
128 |
+
def extract_features(
|
129 |
+
self,
|
130 |
+
prev_output_tokens,
|
131 |
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
132 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
Similar to *forward* but only return features.
|
136 |
+
|
137 |
+
Includes several features from "Jointly Learning to Align and
|
138 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
tuple:
|
142 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
143 |
+
- a dictionary with any model-specific outputs
|
144 |
+
"""
|
145 |
+
bs, slen = prev_output_tokens.size()
|
146 |
+
|
147 |
+
enc: Optional[Tensor] = None
|
148 |
+
padding_mask: Optional[Tensor] = None
|
149 |
+
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
|
150 |
+
enc = encoder_out["encoder_out"][0]
|
151 |
+
assert (
|
152 |
+
enc.size()[1] == bs
|
153 |
+
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
|
154 |
+
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
|
155 |
+
padding_mask = encoder_out["encoder_padding_mask"][0]
|
156 |
+
|
157 |
+
# embed positions
|
158 |
+
positions = self.embed_positions(
|
159 |
+
prev_output_tokens
|
160 |
+
)
|
161 |
+
|
162 |
+
if incremental_state is not None:
|
163 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
164 |
+
positions = positions[:, -1:]
|
165 |
+
|
166 |
+
# embed tokens and positions
|
167 |
+
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
168 |
+
|
169 |
+
if self.project_in_dim is not None:
|
170 |
+
x = self.project_in_dim(x)
|
171 |
+
|
172 |
+
x += positions
|
173 |
+
|
174 |
+
x = self.dropout_module(x)
|
175 |
+
|
176 |
+
# B x T x C -> T x B x C
|
177 |
+
x = x.transpose(0, 1)
|
178 |
+
|
179 |
+
self_attn_padding_mask: Optional[Tensor] = None
|
180 |
+
if prev_output_tokens.eq(self.padding_idx).any():
|
181 |
+
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
|
182 |
+
|
183 |
+
# decoder layers
|
184 |
+
attn: Optional[Tensor] = None
|
185 |
+
inner_states: List[Optional[Tensor]] = [x]
|
186 |
+
for idx, layer in enumerate(self.layers):
|
187 |
+
if incremental_state is None:
|
188 |
+
self_attn_mask = self.buffered_future_mask(x)
|
189 |
+
else:
|
190 |
+
self_attn_mask = None
|
191 |
+
|
192 |
+
x, layer_attn, _ = layer(
|
193 |
+
x,
|
194 |
+
enc,
|
195 |
+
padding_mask,
|
196 |
+
incremental_state,
|
197 |
+
self_attn_mask=self_attn_mask,
|
198 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
199 |
+
need_attn=False,
|
200 |
+
need_head_weights=False,
|
201 |
+
)
|
202 |
+
inner_states.append(x)
|
203 |
+
|
204 |
+
if self.layer_norm is not None:
|
205 |
+
x = self.layer_norm(x)
|
206 |
+
|
207 |
+
# T x B x C -> B x T x C
|
208 |
+
x = x.transpose(0, 1)
|
209 |
+
|
210 |
+
return x, {"inner_states": inner_states}
|
211 |
+
|
212 |
+
def output_layer(self, features):
|
213 |
+
"""Project features to the vocabulary size."""
|
214 |
+
return self.output_projection(features)
|
215 |
+
|
216 |
+
def buffered_future_mask(self, tensor):
|
217 |
+
dim = tensor.size(0)
|
218 |
+
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
|
219 |
+
if (
|
220 |
+
self._future_mask.size(0) == 0
|
221 |
+
or (not self._future_mask.device == tensor.device)
|
222 |
+
or self._future_mask.size(0) < dim
|
223 |
+
):
|
224 |
+
self._future_mask = torch.triu(
|
225 |
+
fill_with_neg_inf(torch.zeros([dim, dim])), 1
|
226 |
+
)
|
227 |
+
self._future_mask = self._future_mask.to(tensor)
|
228 |
+
return self._future_mask[:dim, :dim]
|
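
buffered_future_mask above caches the causal mask used during teacher-forced decoding; the sketch below just materializes the same upper-triangular -inf pattern for a toy length so the effect of fill_with_neg_inf plus torch.triu is visible.

import torch

dim = 4
mask = torch.triu(torch.zeros(dim, dim).fill_(float("-inf")), 1)
# row t keeps 0 up to column t and -inf beyond it, so position t cannot
# attend to future positions
assert mask[1, 0] == 0 and mask[0, 1] == float("-inf")
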
esm/inverse_folding/transformer_layer.py
ADDED
@@ -0,0 +1,304 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# Contents of this file were adapted from the open source fairseq repository.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from esm.multihead_attention import MultiheadAttention
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
|
17 |
+
class TransformerEncoderLayer(nn.Module):
|
18 |
+
"""Encoder layer block.
|
19 |
+
`layernorm -> dropout -> add residual`
|
20 |
+
|
21 |
+
Args:
|
22 |
+
args (argparse.Namespace): parsed command-line arguments
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, args):
|
26 |
+
super().__init__()
|
27 |
+
self.args = args
|
28 |
+
self.embed_dim = args.encoder_embed_dim
|
29 |
+
self.self_attn = self.build_self_attention(self.embed_dim, args)
|
30 |
+
self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim)
|
31 |
+
self.dropout_module = nn.Dropout(args.dropout)
|
32 |
+
self.activation_fn = F.relu
|
33 |
+
self.fc1 = self.build_fc1(
|
34 |
+
self.embed_dim,
|
35 |
+
args.encoder_ffn_embed_dim,
|
36 |
+
)
|
37 |
+
self.fc2 = self.build_fc2(
|
38 |
+
args.encoder_ffn_embed_dim,
|
39 |
+
self.embed_dim,
|
40 |
+
)
|
41 |
+
|
42 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
43 |
+
|
44 |
+
def build_fc1(self, input_dim, output_dim):
|
45 |
+
return nn.Linear(input_dim, output_dim)
|
46 |
+
|
47 |
+
def build_fc2(self, input_dim, output_dim):
|
48 |
+
return nn.Linear(input_dim, output_dim)
|
49 |
+
|
50 |
+
def build_self_attention(self, embed_dim, args):
|
51 |
+
return MultiheadAttention(
|
52 |
+
embed_dim,
|
53 |
+
args.encoder_attention_heads,
|
54 |
+
dropout=args.attention_dropout,
|
55 |
+
self_attention=True,
|
56 |
+
)
|
57 |
+
|
58 |
+
def residual_connection(self, x, residual):
|
59 |
+
return residual + x
|
60 |
+
|
61 |
+
def forward(
|
62 |
+
self,
|
63 |
+
x,
|
64 |
+
encoder_padding_mask: Optional[Tensor],
|
65 |
+
attn_mask: Optional[Tensor] = None,
|
66 |
+
):
|
67 |
+
"""
|
68 |
+
Args:
|
69 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
70 |
+
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
71 |
+
`(batch, seq_len)` where padding elements are indicated by ``1``.
|
72 |
+
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
|
73 |
+
where `tgt_len` is the length of output and `src_len` is the
|
74 |
+
length of input, though here both are equal to `seq_len`.
|
75 |
+
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
|
76 |
+
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
|
77 |
+
useful for strided self-attention.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
81 |
+
"""
|
82 |
+
# anything in original attn_mask = 1, becomes -1e8
|
83 |
+
# anything in original attn_mask = 0, becomes 0
|
84 |
+
# Note that we cannot use -inf here, because at some edge cases,
|
85 |
+
# the attention weight (before softmax) for some padded element in query
|
86 |
+
# will become -inf, which results in NaN in model parameters
|
87 |
+
if attn_mask is not None:
|
88 |
+
attn_mask = attn_mask.masked_fill(
|
89 |
+
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
|
90 |
+
)
|
91 |
+
|
92 |
+
residual = x
|
93 |
+
x = self.self_attn_layer_norm(x)
|
94 |
+
x, _ = self.self_attn(
|
95 |
+
query=x,
|
96 |
+
key=x,
|
97 |
+
value=x,
|
98 |
+
key_padding_mask=encoder_padding_mask,
|
99 |
+
need_weights=False,
|
100 |
+
attn_mask=attn_mask,
|
101 |
+
)
|
102 |
+
x = self.dropout_module(x)
|
103 |
+
x = self.residual_connection(x, residual)
|
104 |
+
|
105 |
+
residual = x
|
106 |
+
x = self.final_layer_norm(x)
|
107 |
+
x = self.activation_fn(self.fc1(x))
|
108 |
+
x = self.fc2(x)
|
109 |
+
x = self.dropout_module(x)
|
110 |
+
x = self.residual_connection(x, residual)
|
111 |
+
return x
|
112 |
+
|
113 |
+
|
114 |
+
class TransformerDecoderLayer(nn.Module):
|
115 |
+
"""Decoder layer block.
|
116 |
+
`layernorm -> dropout -> add residual`
|
117 |
+
|
118 |
+
Args:
|
119 |
+
args (argparse.Namespace): parsed command-line arguments
|
120 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
121 |
+
(default: False).
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
self.embed_dim = args.decoder_embed_dim
|
129 |
+
self.dropout_module = nn.Dropout(args.dropout)
|
130 |
+
|
131 |
+
self.self_attn = self.build_self_attention(
|
132 |
+
self.embed_dim,
|
133 |
+
args,
|
134 |
+
add_bias_kv=add_bias_kv,
|
135 |
+
add_zero_attn=add_zero_attn,
|
136 |
+
)
|
137 |
+
self.nh = self.self_attn.num_heads
|
138 |
+
self.head_dim = self.self_attn.head_dim
|
139 |
+
|
140 |
+
self.activation_fn = F.relu
|
141 |
+
|
142 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
143 |
+
|
144 |
+
if no_encoder_attn:
|
145 |
+
self.encoder_attn = None
|
146 |
+
self.encoder_attn_layer_norm = None
|
147 |
+
else:
|
148 |
+
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
|
149 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
150 |
+
|
151 |
+
self.ffn_layernorm = (
|
152 |
+
LayerNorm(args.decoder_ffn_embed_dim)
|
153 |
+
if getattr(args, "scale_fc", False)
|
154 |
+
else None
|
155 |
+
)
|
156 |
+
self.w_resid = (
|
157 |
+
nn.Parameter(
|
158 |
+
torch.ones(
|
159 |
+
self.embed_dim,
|
160 |
+
),
|
161 |
+
requires_grad=True,
|
162 |
+
)
|
163 |
+
if getattr(args, "scale_resids", False)
|
164 |
+
else None
|
165 |
+
)
|
166 |
+
|
167 |
+
self.fc1 = self.build_fc1(
|
168 |
+
self.embed_dim,
|
169 |
+
args.decoder_ffn_embed_dim,
|
170 |
+
)
|
171 |
+
self.fc2 = self.build_fc2(
|
172 |
+
args.decoder_ffn_embed_dim,
|
173 |
+
self.embed_dim,
|
174 |
+
)
|
175 |
+
|
176 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
177 |
+
self.need_attn = True
|
178 |
+
|
179 |
+
def build_fc1(self, input_dim, output_dim):
|
180 |
+
return nn.Linear(input_dim, output_dim)
|
181 |
+
|
182 |
+
def build_fc2(self, input_dim, output_dim):
|
183 |
+
return nn.Linear(input_dim, output_dim)
|
184 |
+
|
185 |
+
def build_self_attention(
|
186 |
+
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
|
187 |
+
):
|
188 |
+
return MultiheadAttention(
|
189 |
+
embed_dim,
|
190 |
+
args.decoder_attention_heads,
|
191 |
+
dropout=args.attention_dropout,
|
192 |
+
add_bias_kv=add_bias_kv,
|
193 |
+
add_zero_attn=add_zero_attn,
|
194 |
+
self_attention=True,
|
195 |
+
)
|
196 |
+
|
197 |
+
def build_encoder_attention(self, embed_dim, args):
|
198 |
+
return MultiheadAttention(
|
199 |
+
embed_dim,
|
200 |
+
args.decoder_attention_heads,
|
201 |
+
kdim=args.encoder_embed_dim,
|
202 |
+
vdim=args.encoder_embed_dim,
|
203 |
+
dropout=args.attention_dropout,
|
204 |
+
encoder_decoder_attention=True,
|
205 |
+
)
|
206 |
+
|
207 |
+
def residual_connection(self, x, residual):
|
208 |
+
return residual + x
|
209 |
+
|
210 |
+
def forward(
|
211 |
+
self,
|
212 |
+
x,
|
213 |
+
encoder_out: Optional[torch.Tensor] = None,
|
214 |
+
encoder_padding_mask: Optional[torch.Tensor] = None,
|
215 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
216 |
+
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
|
217 |
+
prev_attn_state: Optional[List[torch.Tensor]] = None,
|
218 |
+
self_attn_mask: Optional[torch.Tensor] = None,
|
219 |
+
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
220 |
+
need_attn: bool = False,
|
221 |
+
need_head_weights: bool = False,
|
222 |
+
):
|
223 |
+
"""
|
224 |
+
Args:
|
225 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
226 |
+
encoder_padding_mask (ByteTensor, optional): binary
|
227 |
+
ByteTensor of shape `(batch, src_len)` where padding
|
228 |
+
elements are indicated by ``1``.
|
229 |
+
need_attn (bool, optional): return attention weights
|
230 |
+
need_head_weights (bool, optional): return attention weights
|
231 |
+
for each head (default: return average over heads).
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
235 |
+
"""
|
236 |
+
if need_head_weights:
|
237 |
+
need_attn = True
|
238 |
+
|
239 |
+
residual = x
|
240 |
+
x = self.self_attn_layer_norm(x)
|
241 |
+
if prev_self_attn_state is not None:
|
242 |
+
prev_key, prev_value = prev_self_attn_state[:2]
|
243 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
244 |
+
"prev_key": prev_key,
|
245 |
+
"prev_value": prev_value,
|
246 |
+
}
|
247 |
+
if len(prev_self_attn_state) >= 3:
|
248 |
+
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
|
249 |
+
assert incremental_state is not None
|
250 |
+
self.self_attn._set_input_buffer(incremental_state, saved_state)
|
251 |
+
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
|
252 |
+
y = x
|
253 |
+
|
254 |
+
x, attn = self.self_attn(
|
255 |
+
query=x,
|
256 |
+
key=y,
|
257 |
+
value=y,
|
258 |
+
key_padding_mask=self_attn_padding_mask,
|
259 |
+
incremental_state=incremental_state,
|
260 |
+
need_weights=False,
|
261 |
+
attn_mask=self_attn_mask,
|
262 |
+
)
|
263 |
+
x = self.dropout_module(x)
|
264 |
+
x = self.residual_connection(x, residual)
|
265 |
+
|
266 |
+
if self.encoder_attn is not None and encoder_out is not None:
|
267 |
+
residual = x
|
268 |
+
x = self.encoder_attn_layer_norm(x)
|
269 |
+
if prev_attn_state is not None:
|
270 |
+
prev_key, prev_value = prev_attn_state[:2]
|
271 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
272 |
+
"prev_key": prev_key,
|
273 |
+
"prev_value": prev_value,
|
274 |
+
}
|
275 |
+
if len(prev_attn_state) >= 3:
|
276 |
+
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
277 |
+
assert incremental_state is not None
|
278 |
+
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
279 |
+
|
280 |
+
x, attn = self.encoder_attn(
|
281 |
+
query=x,
|
282 |
+
key=encoder_out,
|
283 |
+
value=encoder_out,
|
284 |
+
key_padding_mask=encoder_padding_mask,
|
285 |
+
incremental_state=incremental_state,
|
286 |
+
static_kv=True,
|
287 |
+
need_weights=need_attn or (not self.training and self.need_attn),
|
288 |
+
need_head_weights=need_head_weights,
|
289 |
+
)
|
290 |
+
x = self.dropout_module(x)
|
291 |
+
x = self.residual_connection(x, residual)
|
292 |
+
|
293 |
+
residual = x
|
294 |
+
x = self.final_layer_norm(x)
|
295 |
+
|
296 |
+
x = self.activation_fn(self.fc1(x))
|
297 |
+
if self.ffn_layernorm is not None:
|
298 |
+
x = self.ffn_layernorm(x)
|
299 |
+
x = self.fc2(x)
|
300 |
+
x = self.dropout_module(x)
|
301 |
+
if self.w_resid is not None:
|
302 |
+
residual = torch.mul(self.w_resid, residual)
|
303 |
+
x = self.residual_connection(x, residual)
|
304 |
+
return x, attn, None
|
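For orientation, both layer classes in esm/inverse_folding/transformer_layer.py above follow the pre-layernorm wiring named in the decoder docstring (`layernorm -> dropout -> add residual`). A minimal standalone sketch of that pattern, in plain PyTorch with hypothetical names (not part of this repository):

import torch
import torch.nn as nn

class PreLNResidual(nn.Module):
    """Toy pre-layernorm residual wrapper: normalize, transform, dropout, add residual."""
    def __init__(self, dim, inner: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.inner = inner
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.inner(self.norm(x))
        return residual + self.dropout(x)

block = PreLNResidual(64, nn.Linear(64, 64))
out = block(torch.randn(10, 2, 64))  # (seq_len, batch, embed_dim), same layout as the layers above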
esm/inverse_folding/util.py
ADDED
@@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import math

import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
import numpy as np
from scipy.spatial import transform
from scipy.stats import special_ortho_group
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from typing import Sequence, Tuple, List

from esm.data import BatchConverter


def load_structure(fpath, chain=None):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id or list of chain ids to load
    Returns:
        biotite.structure.AtomArray
    """
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]
    for chain in chain_ids:
        if chain not in all_chains:
            raise ValueError(f'Chain {chain} not found in input file')
    chain_filter = [a.chain_id in chain_ids for a in structure]
    structure = structure[chain_filter]
    return structure


def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    residue_identities = get_residues(structure)[1]
    seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return coords, seq


def load_coords(fpath, chain):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    structure = load_structure(fpath, chain)
    return extract_coords_from_structure(structure)


def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Example for atoms argument: ["N", "CA", "C"]
    """
    def filterfn(s, axis=None):
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        sum = filters.sum(0)
        if not np.all(sum <= np.ones(filters.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        coords[sum == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)


def get_sequence_loss(model, alphabet, coords, seq):
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, seq)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(
        batch, device=device)

    prev_output_tokens = tokens[:, :-1].to(device)
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)
    loss = F.cross_entropy(logits, target, reduction='none')
    loss = loss[0].cpu().detach().numpy()
    target_padding_mask = target_padding_mask[0].cpu().numpy()
    return loss, target_padding_mask


def score_sequence(model, alphabet, coords, seq):
    loss, target_padding_mask = get_sequence_loss(model, alphabet, coords, seq)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask)
    # Also calculate average when excluding masked portions
    coord_mask = np.all(np.isfinite(coords), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord


def get_encoder_output(model, alphabet, coords):
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, None)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(
        batch, device=device)
    encoder_out = model.encoder.forward(coords, padding_mask, confidence,
                                        return_all_hiddens=False)
    # remove beginning and end (bos and eos tokens)
    return encoder_out['encoder_out'][0][1:-1, 0]


def rotate(v, R):
    """
    Rotates a vector by a rotation matrix.

    Args:
        v: 3D vector, tensor of shape (length x batch_size x channels x 3)
        R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)

    Returns:
        Rotated version of v by rotation matrix R.
    """
    R = R.unsqueeze(-3)
    v = v.unsqueeze(-1)
    return torch.sum(v * R, dim=-2)


def get_rotation_frames(coords):
    """
    Returns a local rotation frame defined by N, CA, C positions.

    Args:
        coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
        where the third dimension is in order of N, CA, C

    Returns:
        Local relative rotation frames in shape (batch_size x length x 3 x 3)
    """
    v1 = coords[:, :, 2] - coords[:, :, 1]
    v2 = coords[:, :, 0] - coords[:, :, 1]
    e1 = normalize(v1, dim=-1)
    u2 = v2 - e1 * torch.sum(e1 * v2, dim=-1, keepdim=True)
    e2 = normalize(u2, dim=-1)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack([e1, e2, e3], dim=-2)
    return R


def nan_to_num(ts, val=0.0):
    """
    Replaces nans in tensor with a fixed value.
    """
    val = torch.tensor(val, dtype=ts.dtype, device=ts.device)
    return torch.where(~torch.isfinite(ts), val, ts)


def rbf(values, v_min, v_max, n_bins=16):
    """
    Returns RBF encodings in a new dimension at the end.
    """
    rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
    rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
    rbf_std = (v_max - v_min) / n_bins
    v_expand = torch.unsqueeze(values, -1)
    z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
    return torch.exp(-z ** 2)


def norm(tensor, dim, eps=1e-8, keepdim=False):
    """
    Returns L2 norm along a dimension.
    """
    return torch.sqrt(
        torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim) + eps)


def normalize(tensor, dim=-1):
    """
    Normalizes a tensor along a dimension after removing nans.
    """
    return nan_to_num(
        torch.div(tensor, norm(tensor, dim=dim, keepdim=True))
    )


class CoordBatchConverter(BatchConverter):
    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
        """
        Args:
            raw_batch: List of tuples (coords, confidence, seq)
            In each tuple,
                coords: list of floats, shape L x 3 x 3
                confidence: list of floats, shape L; or scalar float; or None
                seq: string of length L
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")
        batch = []
        for coords, confidence, seq in raw_batch:
            if confidence is None:
                confidence = 1.
            if isinstance(confidence, float) or isinstance(confidence, int):
                confidence = [float(confidence)] * len(coords)
            if seq is None:
                seq = 'X' * len(coords)
            batch.append(((coords, confidence), seq))

        coords_and_confidence, strs, tokens = super().__call__(batch)

        # pad beginning and end of each protein due to legacy reasons
        coords = [
            F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)
            for cd, _ in coords_and_confidence
        ]
        confidence = [
            F.pad(torch.tensor(cf), (1, 1), value=-1.)
            for _, cf in coords_and_confidence
        ]
        coords = self.collate_dense_tensors(coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
        if device is not None:
            coords = coords.to(device)
            confidence = confidence.to(device)
            tokens = tokens.to(device)
        padding_mask = torch.isnan(coords[:, :, 0, 0])
        coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
        confidence = confidence * coord_mask + (-1.) * padding_mask
        return coords, confidence, strs, tokens, padding_mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
                floats in shape L x 3 x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                    length L to describe the confidence scores for the backbone
                    with values between 0. and 1.
            seq_list: either None or a list of strings
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        batch_size = len(coords_list)
        if confidence_list is None:
            confidence_list = [None] * batch_size
        if seq_list is None:
            seq_list = [None] * batch_size
        raw_batch = zip(coords_list, confidence_list, seq_list)
        return self.__call__(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11, ..., d_1K),
             (d_21, ..., d_2K),
             ...,
             (d_N1, ..., d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
        """
        if len(samples) == 0:
            return torch.Tensor()
        if len(set(x.dim() for x in samples)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {[x.dim() for x in samples]}"
            )
        (device,) = tuple(set(x.device for x in samples))  # assumes all on same device
        max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
        result = torch.empty(
            len(samples), *max_shape, dtype=samples[0].dtype, device=device
        )
        result.fill_(pad_v)
        for i in range(len(samples)):
            result_i = result[i]
            t = samples[i]
            result_i[tuple(slice(0, k) for k in t.shape)] = t
        return result
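A hedged usage sketch for the helpers above (the pretrained loader name, the PDB path, and the chain id are assumptions for illustration, not defined in this file):

import esm
import esm.inverse_folding

# Assumed inverse-folding checkpoint loader exposed by esm/pretrained.py.
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()

# "example.pdb" and chain "A" are placeholders.
coords, native_seq = esm.inverse_folding.util.load_coords("example.pdb", "A")
ll_fullseq, ll_withcoord = esm.inverse_folding.util.score_sequence(
    model, alphabet, coords, native_seq
)
print(f"avg log-likelihood: {ll_fullseq:.3f} (all residues), {ll_withcoord:.3f} (with coordinates)")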
esm/model/__init__.py
ADDED
@@ -0,0 +1 @@
esm/model/esm1.py
ADDED
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..modules import (
    TransformerLayer,
    LearnedPositionalEmbedding,
    SinusoidalPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
)


class ProteinBertModel(nn.Module):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument(
            "--num_layers", default=36, type=int, metavar="N", help="number of layers"
        )
        parser.add_argument(
            "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias", action="store_true", help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=5120,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=20,
            type=int,
            metavar="N",
            help="number of attention heads",
        )

    def __init__(self, args, alphabet):
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
        if self.args.arch == "roberta_large":
            self.model_version = "ESM-1b"
            self._init_submodules_esm1b()
        else:
            self.model_version = "ESM-1"
            self._init_submodules_esm1()

    def _init_submodules_common(self):
        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    add_bias_kv=(self.model_version != "ESM-1b"),
                    use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )

    def _init_submodules_esm1b(self):
        self._init_submodules_common()
        self.embed_scale = 1
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions, self.args.embed_dim, self.padding_idx
        )
        self.emb_layer_norm_before = (
            ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def _init_submodules_esm1(self):
        self._init_submodules_common()
        self.embed_scale = math.sqrt(self.args.embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
        self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
        self.embed_out_bias = None
        if self.args.final_bias:
            self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if getattr(self.args, "token_dropout", False):
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        x = x + self.embed_positions(tokens)

        if self.model_version == "ESM-1b":
            if self.emb_layer_norm_before:
                x = self.emb_layer_norm_before(x)
            if padding_mask is not None:
                x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        if self.model_version == "ESM-1b":
            x = self.emb_layer_norm_after(x)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

            # last hidden representation should have layer norm applied
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x
            x = self.lm_head(x)
        else:
            x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if self.model_version == "ESM-1":
                # ESM-1 models have an additional null-token for attention, which we remove
                attentions = attentions[..., :-1]
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
        if return_contacts:
            contacts = self.contact_head(tokens, attentions)
            result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        return self.args.layers
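A usage sketch for ProteinBertModel through the pretrained loaders (the checkpoint loader name is assumed and the input sequence is a placeholder):

import torch
import esm

model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()  # assumed ESM-1b checkpoint loader
batch_converter = alphabet.get_batch_converter()
model.eval()

data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT")]
labels, strs, tokens = batch_converter(data)

with torch.no_grad():
    out = model(tokens, repr_layers=[33], return_contacts=True)

per_residue = out["representations"][33]  # B x T x C final-layer embeddings
contacts = out["contacts"]                # B x T' x T' contact probabilities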
esm/model/esm2.py
ADDED
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union
import torch
import torch.nn as nn

import esm
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer


class ESM2(nn.Module):
    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
        token_dropout: bool = True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        if not isinstance(alphabet, esm.data.Alphabet):
            alphabet = esm.data.Alphabet.from_architecture(alphabet)
        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.token_dropout = token_dropout

        self._init_submodules()

    def _init_submodules(self):
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )

        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if self.token_dropout:
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
        if return_contacts:
            contacts = self.contact_head(tokens, attentions)
            result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]
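The token-dropout branch in ESM2.forward zeroes the embeddings at <mask> positions and rescales the remaining ones so that their expected magnitude matches pre-training, where 15% x 80% = 12% of tokens were replaced by <mask>. A small self-contained sketch of just that rescaling on toy tensors (not the model's own code path):

import torch

B, T, C = 2, 8, 4
x = torch.ones(B, T, C)
masked = torch.zeros(B, T, dtype=torch.bool)
masked[:, 0] = True                      # pretend one token per sequence is <mask>

x = x.masked_fill(masked.unsqueeze(-1), 0.0)
mask_ratio_train = 0.15 * 0.8            # 12% of tokens masked during pre-training
mask_ratio_observed = masked.sum(-1).float() / T
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
# here each unmasked embedding is scaled by 0.88 / 0.875, roughly 1.006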
esm/model/msa_transformer.py
ADDED
@@ -0,0 +1,238 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ..modules import (
    AxialTransformerLayer,
    LearnedPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
)

from ..axial_attention import RowSelfAttention, ColumnSelfAttention


class MSATransformer(nn.Module):
    @classmethod
    def add_args(cls, parser):
        # fmt: off
        parser.add_argument(
            "--num_layers",
            default=12,
            type=int,
            metavar="N",
            help="number of layers"
        )
        parser.add_argument(
            "--embed_dim",
            default=768,
            type=int,
            metavar="N",
            help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias",
            action="store_true",
            help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=3072,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=12,
            type=int,
            metavar="N",
            help="number of attention heads",
        )
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--attention_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--activation_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--max_tokens_per_msa",
            default=2 ** 14,
            type=int,
            help=(
                "Used during inference to batch attention computations in a single "
                "forward pass. This allows increased input sizes with less memory."
            ),
        )
        # fmt: on

    def __init__(self, args, alphabet):
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )

        if getattr(self.args, "embed_positions_msa", False):
            emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, 1024, 1, emb_dim),
                requires_grad=True,
            )
        else:
            self.register_parameter("msa_position_embedding", None)

        self.dropout_module = nn.Dropout(self.args.dropout)
        self.layers = nn.ModuleList(
            [
                AxialTransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    self.args.dropout,
                    self.args.attention_dropout,
                    self.args.activation_dropout,
                    getattr(self.args, "max_tokens_per_msa", self.args.max_tokens),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions,
            self.args.embed_dim,
            self.padding_idx,
        )
        self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 3
        batch_size, num_alignments, seqlen = tokens.size()
        padding_mask = tokens.eq(self.padding_idx)  # B, R, C
        if not padding_mask.any():
            padding_mask = None

        x = self.embed_tokens(tokens)
        x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
        if self.msa_position_embedding is not None:
            if x.size(1) > 1024:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of 1024, but received {x.size(1)} alignments."
                )
            x += self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)

        x = self.dropout_module(x)

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            row_attn_weights = []
            col_attn_weights = []

        # B x R x C x D -> R x C x B x D
        x = x.permute(1, 2, 0, 3)

        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if need_head_weights:
                x, col_attn, row_attn = x
                # H x C x B x R x R -> B x H x C x R x R
                col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
                # H x B x C x C -> B x H x C x C
                row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # col_attentions: B x L x H x C x R x R
            col_attentions = torch.stack(col_attn_weights, 1)
            # row_attentions: B x L x H x C x C
            row_attentions = torch.stack(row_attn_weights, 1)
            result["col_attentions"] = col_attentions
            result["row_attentions"] = row_attentions
            if return_contacts:
                contacts = self.contact_head(tokens, row_attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        return self.args.layers

    def max_tokens_per_msa_(self, value: int) -> None:
        """The MSA Transformer automatically batches attention computations when
        gradients are disabled to allow you to pass in larger MSAs at test time than
        you can fit in GPU memory. By default this occurs when more than 2^14 tokens
        are passed in the input MSA. You can set this value to infinity to disable
        this behavior.
        """
        for module in self.modules():
            if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
                module.max_tokens_per_msa = value
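A usage sketch for the MSA Transformer (the checkpoint loader name is assumed; the toy MSA is a placeholder, and every aligned sequence must have the same length):

import torch
import esm

model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()  # assumed MSA Transformer checkpoint loader
batch_converter = alphabet.get_batch_converter()
model.eval()

# One MSA = a list of (label, aligned_sequence) pairs.
msa = [
    ("seq0", "MKTVRQERLK"),
    ("seq1", "MKTVRQERLE"),
    ("seq2", "MKTARQERLK"),
]
labels, strs, tokens = batch_converter([msa])  # tokens: 1 x num_seqs x seq_len

with torch.no_grad():
    contacts = model.predict_contacts(tokens)  # 1 x seq_len x seq_len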
esm/modules.py
ADDED
@@ -0,0 +1,418 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from .multihead_attention import MultiheadAttention # noqa
|
14 |
+
from .axial_attention import ColumnSelfAttention, RowSelfAttention
|
15 |
+
|
16 |
+
|
17 |
+
def gelu(x):
|
18 |
+
"""Implementation of the gelu activation function.
|
19 |
+
|
20 |
+
For information: OpenAI GPT's gelu is slightly different
|
21 |
+
(and gives slightly different results):
|
22 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
23 |
+
"""
|
24 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
25 |
+
|
26 |
+
|
27 |
+
def symmetrize(x):
|
28 |
+
"Make layer symmetric in final two dimensions, used for contact prediction."
|
29 |
+
return x + x.transpose(-1, -2)
|
30 |
+
|
31 |
+
|
32 |
+
def apc(x):
|
33 |
+
"Perform average product correct, used for contact prediction."
|
34 |
+
a1 = x.sum(-1, keepdims=True)
|
35 |
+
a2 = x.sum(-2, keepdims=True)
|
36 |
+
a12 = x.sum((-1, -2), keepdims=True)
|
37 |
+
|
38 |
+
avg = a1 * a2
|
39 |
+
avg.div_(a12) # in-place to reduce memory
|
40 |
+
normalized = x - avg
|
41 |
+
return normalized
|
42 |
+
|
43 |
+
|
44 |
+
class ESM1LayerNorm(nn.Module):
|
45 |
+
def __init__(self, hidden_size, eps=1e-12, affine=True):
|
46 |
+
"""Construct a layernorm layer in the TF style (eps inside the sqrt)."""
|
47 |
+
super().__init__()
|
48 |
+
self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
|
49 |
+
self.eps = eps
|
50 |
+
self.affine = bool(affine)
|
51 |
+
if self.affine:
|
52 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
53 |
+
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
54 |
+
else:
|
55 |
+
self.weight, self.bias = None, None
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
|
59 |
+
means = x.mean(dims, keepdim=True)
|
60 |
+
x_zeromean = x - means
|
61 |
+
variances = x_zeromean.pow(2).mean(dims, keepdim=True)
|
62 |
+
x = x_zeromean / torch.sqrt(variances + self.eps)
|
63 |
+
if self.affine:
|
64 |
+
x = (self.weight * x) + self.bias
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
try:
|
69 |
+
from apex.normalization import FusedLayerNorm as _FusedLayerNorm
|
70 |
+
|
71 |
+
class ESM1bLayerNorm(_FusedLayerNorm):
|
72 |
+
@torch.jit.unused
|
73 |
+
def forward(self, x):
|
74 |
+
if not x.is_cuda:
|
75 |
+
return super().forward(x)
|
76 |
+
else:
|
77 |
+
with torch.cuda.device(x.device):
|
78 |
+
return super().forward(x)
|
79 |
+
|
80 |
+
except ImportError:
|
81 |
+
from torch.nn import LayerNorm as ESM1bLayerNorm
|
82 |
+
|
83 |
+
|
84 |
+
class TransformerLayer(nn.Module):
|
85 |
+
"""Transformer layer block."""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
embed_dim,
|
90 |
+
ffn_embed_dim,
|
91 |
+
attention_heads,
|
92 |
+
add_bias_kv=True,
|
93 |
+
use_esm1b_layer_norm=False,
|
94 |
+
use_rotary_embeddings: bool = False,
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.embed_dim = embed_dim
|
98 |
+
self.ffn_embed_dim = ffn_embed_dim
|
99 |
+
self.attention_heads = attention_heads
|
100 |
+
self.use_rotary_embeddings = use_rotary_embeddings
|
101 |
+
self._init_submodules(add_bias_kv, use_esm1b_layer_norm)
|
102 |
+
|
103 |
+
def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
|
104 |
+
BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm
|
105 |
+
|
106 |
+
self.self_attn = MultiheadAttention(
|
107 |
+
self.embed_dim,
|
108 |
+
self.attention_heads,
|
109 |
+
add_bias_kv=add_bias_kv,
|
110 |
+
add_zero_attn=False,
|
111 |
+
use_rotary_embeddings=self.use_rotary_embeddings,
|
112 |
+
)
|
113 |
+
self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)
|
114 |
+
|
115 |
+
self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
|
116 |
+
self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
|
117 |
+
|
118 |
+
self.final_layer_norm = BertLayerNorm(self.embed_dim)
|
119 |
+
|
120 |
+
def forward(
|
121 |
+
self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
|
122 |
+
):
|
123 |
+
residual = x
|
124 |
+
x = self.self_attn_layer_norm(x)
|
125 |
+
x, attn = self.self_attn(
|
126 |
+
query=x,
|
127 |
+
key=x,
|
128 |
+
value=x,
|
129 |
+
key_padding_mask=self_attn_padding_mask,
|
130 |
+
need_weights=True,
|
131 |
+
need_head_weights=need_head_weights,
|
132 |
+
attn_mask=self_attn_mask,
|
133 |
+
)
|
134 |
+
x = residual + x
|
135 |
+
|
136 |
+
residual = x
|
137 |
+
x = self.final_layer_norm(x)
|
138 |
+
x = gelu(self.fc1(x))
|
139 |
+
x = self.fc2(x)
|
140 |
+
x = residual + x
|
141 |
+
|
142 |
+
return x, attn
|
143 |
+
|
144 |
+
|
145 |
+
class AxialTransformerLayer(nn.Module):
|
146 |
+
"""Implements an Axial MSA Transformer block."""
|
147 |
+
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
embedding_dim: int = 768,
|
151 |
+
ffn_embedding_dim: int = 3072,
|
152 |
+
num_attention_heads: int = 8,
|
153 |
+
dropout: float = 0.1,
|
154 |
+
attention_dropout: float = 0.1,
|
155 |
+
activation_dropout: float = 0.1,
|
156 |
+
max_tokens_per_msa: int = 2**14,
|
157 |
+
) -> None:
|
158 |
+
super().__init__()
|
159 |
+
|
160 |
+
# Initialize parameters
|
161 |
+
self.embedding_dim = embedding_dim
|
162 |
+
self.dropout_prob = dropout
|
163 |
+
|
164 |
+
row_self_attention = RowSelfAttention(
|
165 |
+
embedding_dim,
|
166 |
+
num_attention_heads,
|
167 |
+
dropout=dropout,
|
168 |
+
max_tokens_per_msa=max_tokens_per_msa,
|
169 |
+
)
|
170 |
+
|
171 |
+
column_self_attention = ColumnSelfAttention(
|
172 |
+
embedding_dim,
|
173 |
+
num_attention_heads,
|
174 |
+
dropout=dropout,
|
175 |
+
max_tokens_per_msa=max_tokens_per_msa,
|
176 |
+
)
|
177 |
+
|
178 |
+
feed_forward_layer = FeedForwardNetwork(
|
179 |
+
embedding_dim,
|
180 |
+
ffn_embedding_dim,
|
181 |
+
activation_dropout=activation_dropout,
|
182 |
+
max_tokens_per_msa=max_tokens_per_msa,
|
183 |
+
)
|
184 |
+
|
185 |
+
self.row_self_attention = self.build_residual(row_self_attention)
|
186 |
+
self.column_self_attention = self.build_residual(column_self_attention)
|
187 |
+
self.feed_forward_layer = self.build_residual(feed_forward_layer)
|
188 |
+
|
189 |
+
def build_residual(self, layer: nn.Module):
|
190 |
+
return NormalizedResidualBlock(
|
191 |
+
layer,
|
192 |
+
self.embedding_dim,
|
193 |
+
self.dropout_prob,
|
194 |
+
)
|
195 |
+
|
196 |
+
def forward(
|
197 |
+
self,
|
198 |
+
x: torch.Tensor,
|
199 |
+
self_attn_mask: Optional[torch.Tensor] = None,
|
200 |
+
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
201 |
+
need_head_weights: bool = False,
|
202 |
+
):
|
203 |
+
"""
|
204 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
205 |
+
modules similar to the original Transformer implementation.
|
206 |
+
"""
|
207 |
+
x, row_attn = self.row_self_attention(
|
208 |
+
x,
|
209 |
+
self_attn_mask=self_attn_mask,
|
210 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
211 |
+
)
|
212 |
+
x, column_attn = self.column_self_attention(
|
213 |
+
x,
|
214 |
+
self_attn_mask=self_attn_mask,
|
215 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
216 |
+
)
|
217 |
+
x = self.feed_forward_layer(x)
|
218 |
+
if need_head_weights:
|
219 |
+
return x, column_attn, row_attn
|
220 |
+
else:
|
221 |
+
return x
|
222 |
+
|
223 |
+
|
224 |
+
class LearnedPositionalEmbedding(nn.Embedding):
|
225 |
+
"""
|
226 |
+
This module learns positional embeddings up to a fixed maximum size.
|
227 |
+
Padding ids are ignored by either offsetting based on padding_idx
|
228 |
+
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        mask = input.ne(self.padding_idx).int()
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        bsz, seq_len = x.shape
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)

        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        mask = x.ne(self.padding_idx)
        range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        positions = range_buf.expand_as(x)
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        half_dim = self.embed_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb


class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x


class ContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # remove eos token attentions
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: B x C x T x T
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


class NormalizedResidualBlock(nn.Module):
    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        residual = x
        x = self.layer_norm(x)
        outputs = self.layer(x, *args, **kwargs)
        if isinstance(outputs, tuple):
            x, *out = outputs
        else:
            x = outputs
            out = None

        x = self.dropout_module(x)
        x = residual + x

        if out is not None:
            return (x,) + tuple(out)
        else:
            return x


class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        return x
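A quick illustrative sketch (not part of the upload) of the padding-aware position computation used in LearnedPositionalEmbedding above; the padding_idx value of 1 is an assumption chosen for the example:

# Hypothetical standalone check: positions are counted only over non-padding
# tokens and offset by padding_idx, so pad slots keep the padding_idx id.
import torch

padding_idx = 1  # assumption for illustration
tokens = torch.tensor([[5, 6, 7, 1, 1]])  # 1 = pad
mask = tokens.ne(padding_idx).int()
positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(positions)  # tensor([[2, 3, 4, 1, 1]])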
esm/multihead_attention.py
ADDED
@@ -0,0 +1,508 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
from esm.rotary_embedding import RotaryEmbedding

import uuid


def utils_softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


def with_incremental_state(cls):
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls


@with_incremental_state
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = RotaryEmbedding(dim=self.head_dim)

        self.enable_torch_version = False
        if hasattr(F, "multi_head_attention_forward"):
            self.enable_torch_version = True
        else:
            self.enable_torch_version = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            not self.rot_emb
            and self.enable_torch_version
            and not self.onnx_trace
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and not need_head_weights
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
                        0
                    ):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
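A minimal usage sketch (an illustration, not part of the upload) of the MultiheadAttention module above, assuming the Time x Batch x Channel layout documented in its forward docstring; the hyperparameters are arbitrary:

# Hypothetical self-attention forward pass on random data.
import torch
from esm.multihead_attention import MultiheadAttention

attn = MultiheadAttention(
    embed_dim=64, num_heads=4, self_attention=True, use_rotary_embeddings=True
)
x = torch.randn(10, 2, 64)  # (tgt_len, bsz, embed_dim)
out, weights = attn(x, x, x, need_weights=True)
print(out.shape)      # torch.Size([10, 2, 64])
print(weights.shape)  # torch.Size([2, 10, 10]), averaged over heads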
esm/pretrained.py
ADDED
@@ -0,0 +1,552 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
import urllib
import warnings
from argparse import Namespace
from pathlib import Path

import torch

import esm
from esm.model.esm2 import ESM2


def _has_regression_weights(model_name):
    """Return whether we expect / require regression weights;
    Right now that is all models except ESM-1v, ESM-IF, and partially trained ESM2 models"""
    return not ("esm1v" in model_name or "esm_if" in model_name or "270K" in model_name or "500K" in model_name)


def load_model_and_alphabet(model_name):
    if model_name.endswith(".pt"):  # treat as filepath
        return load_model_and_alphabet_local(model_name)
    else:
        return load_model_and_alphabet_hub(model_name)


def load_hub_workaround(url):
    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        fn = Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError as e:
        raise Exception(f"Could not load {url}, check if you specified a correct model name?")
    return data


def load_regression_hub(model_name):
    url = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"
    regression_data = load_hub_workaround(url)
    return regression_data


def _download_model_and_regression_data(model_name):
    url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
    model_data = load_hub_workaround(url)
    if _has_regression_weights(model_name):
        regression_data = load_regression_hub(model_name)
    else:
        regression_data = None
    return model_data, regression_data


def load_model_and_alphabet_hub(model_name):
    model_data, regression_data = _download_model_and_regression_data(model_name)
    return load_model_and_alphabet_core(model_name, model_data, regression_data)


def load_model_and_alphabet_local(model_location):
    """Load from local path. The regression weights need to be co-located"""
    model_location = Path(model_location)
    model_data = torch.load(str(model_location), map_location="cpu")
    model_name = model_location.stem
    if _has_regression_weights(model_name):
        regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt"
        regression_data = torch.load(regression_location, map_location="cpu")
    else:
        regression_data = None
    return load_model_and_alphabet_core(model_name, model_data, regression_data)


def has_emb_layer_norm_before(model_state):
    """Determine whether layer norm needs to be applied before the encoder"""
    return any(k.startswith("emb_layer_norm_before") for k, param in model_state.items())


def _load_model_and_alphabet_core_v1(model_data):
    import esm  # since esm.inverse_folding is imported below, you actually have to re-import esm here

    alphabet = esm.Alphabet.from_architecture(model_data["args"].arch)

    if model_data["args"].arch == "roberta_large":
        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()}
        model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()  # For token drop
        model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state)
        model_type = esm.ProteinBertModel

    elif model_data["args"].arch == "protein_bert_base":

        # upgrade state dict
        pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
        prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
        model_type = esm.ProteinBertModel
    elif model_data["args"].arch == "msa_transformer":

        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
        prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()}
        if model_args.get("embed_positions_msa", False):
            emb_dim = model_state["msa_position_embedding"].size(-1)
            model_args["embed_positions_msa_dim"] = emb_dim  # initial release, bug: emb_dim==1

        model_type = esm.MSATransformer

    elif "invariant_gvp" in model_data["args"].arch:
        import esm.inverse_folding

        model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel
        model_args = vars(model_data["args"])  # convert Namespace -> dict

        def update_name(s):
            # Map the module names in checkpoints trained with internal code to
            # the updated module names in open source code
            s = s.replace("W_v", "embed_graph.embed_node")
            s = s.replace("W_e", "embed_graph.embed_edge")
            s = s.replace("embed_scores.0", "embed_confidence")
            s = s.replace("embed_score.", "embed_graph.embed_confidence.")
            s = s.replace("seq_logits_projection.", "")
            s = s.replace("embed_ingraham_features", "embed_dihedrals")
            s = s.replace("embed_gvp_in_local_frame.0", "embed_gvp_output")
            s = s.replace("embed_features_in_local_frame.0", "embed_gvp_input_features")
            return s

        model_state = {
            update_name(sname): svalue
            for sname, svalue in model_data["model"].items()
            if "version" not in sname
        }

    else:
        raise ValueError("Unknown architecture selected")

    model = model_type(
        Namespace(**model_args),
        alphabet,
    )

    return model, alphabet, model_state


def _load_model_and_alphabet_core_v2(model_data):
    def upgrade_state_dict(state_dict):
        """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
        prefixes = ["encoder.sentence_encoder.", "encoder."]
        pattern = re.compile("^" + "|".join(prefixes))
        state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
        return state_dict

    cfg = model_data["cfg"]["model"]
    state_dict = model_data["model"]
    state_dict = upgrade_state_dict(state_dict)
    alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
    model = ESM2(
        num_layers=cfg.encoder_layers,
        embed_dim=cfg.encoder_embed_dim,
        attention_heads=cfg.encoder_attention_heads,
        alphabet=alphabet,
        token_dropout=cfg.token_dropout,
    )
    return model, alphabet, state_dict


def load_model_and_alphabet_core(model_name, model_data, regression_data=None):
    if regression_data is not None:
        model_data["model"].update(regression_data["model"])

    if model_name.startswith("esm2"):
        model, alphabet, model_state = _load_model_and_alphabet_core_v2(model_data)
    else:
        model, alphabet, model_state = _load_model_and_alphabet_core_v1(model_data)

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    if regression_data is None:
        expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"}
        error_msgs = []
        missing = (expected_keys - found_keys) - expected_missing
        if missing:
            error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
        unexpected = found_keys - expected_keys
        if unexpected:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")

        if error_msgs:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        if expected_missing - found_keys:
            warnings.warn(
                "Regression weights not found, predicting contacts will not produce correct results."
            )

    model.load_state_dict(model_state, strict=regression_data is not None)

    return model, alphabet


def esm1_t34_670M_UR50S():
    """34 layer transformer model with 670M params, trained on Uniref50 Sparse.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t34_670M_UR50S")


def esm1_t34_670M_UR50D():
    """34 layer transformer model with 670M params, trained on Uniref50 Dense.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t34_670M_UR50D")


def esm1_t34_670M_UR100():
    """34 layer transformer model with 670M params, trained on Uniref100.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t34_670M_UR100")


def esm1_t12_85M_UR50S():
    """12 layer transformer model with 85M params, trained on Uniref50 Sparse.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t12_85M_UR50S")


def esm1_t6_43M_UR50S():
    """6 layer transformer model with 43M params, trained on Uniref50 Sparse.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t6_43M_UR50S")


def esm1b_t33_650M_UR50S():
    """33 layer transformer model with 650M params, trained on Uniref50 Sparse.
    This is our best performing model, which will be described in a future publication.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1b_t33_650M_UR50S")


def esm_msa1_t12_100M_UR50S():
    warnings.warn(
        "This model had a minor bug in the positional embeddings, "
        "please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()",
    )
    return load_model_and_alphabet_hub("esm_msa1_t12_100M_UR50S")


def esm_msa1b_t12_100M_UR50S():
    return load_model_and_alphabet_hub("esm_msa1b_t12_100M_UR50S")


def esm1v_t33_650M_UR90S():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 1 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")


def esm1v_t33_650M_UR90S_1():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 1 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")


def esm1v_t33_650M_UR90S_2():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 2 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_2")


def esm1v_t33_650M_UR90S_3():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 3 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_3")


def esm1v_t33_650M_UR90S_4():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 4 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_4")


def esm1v_t33_650M_UR90S_5():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 5 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_5")


def esm_if1_gvp4_t16_142M_UR50():
    """Inverse folding model with 142M params, with 4 GVP-GNN layers, 8
    Transformer encoder layers, and 8 Transformer decoder layers, trained on
    CATH structures and 12 million alphafold2 predicted structures from UniRef50
    sequences.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm_if1_gvp4_t16_142M_UR50")


def esm2_t6_8M_UR50D():
    """6 layer ESM-2 model with 8M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t6_8M_UR50D")


def esm2_t12_35M_UR50D():
    """12 layer ESM-2 model with 35M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t12_35M_UR50D")


def esm2_t30_150M_UR50D():
    """30 layer ESM-2 model with 150M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t30_150M_UR50D")


def esm2_t33_650M_UR50D():
    """33 layer ESM-2 model with 650M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t33_650M_UR50D")


def esm2_t36_3B_UR50D():
    """36 layer ESM-2 model with 3B params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t36_3B_UR50D")


def esm2_t48_15B_UR50D():
    """48 layer ESM-2 model with 15B params, trained on UniRef50.
    If you have OOM while loading this model, please refer to README
    on how to employ FSDP and ZeRO CPU offloading

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t48_15B_UR50D")


def esmfold_v0():
    """
    ESMFold v0 model with 3B ESM-2, 48 folding blocks.
    This version was used for the paper (Lin et al, 2022). It was trained
    on all PDB chains until 2020-05, to ensure temporal holdout with CASP14
    and the CAMEO validation and test set reported there.
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_v0()


def esmfold_v1():
    """
    ESMFold v1 model using 3B ESM-2, 48 folding blocks.
    ESMFold provides fast high accuracy atomic level structure prediction
    directly from the individual sequence of a protein. ESMFold uses the ESM2
    protein language model to extract meaningful representations from the
    protein sequence.
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_v1()

def esmfold_structure_module_only_8M():
    """
    ESMFold baseline model using 8M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M()


def esmfold_structure_module_only_8M_270K():
    """
    ESMFold baseline model using 8M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M_270K()


def esmfold_structure_module_only_35M():
    """
    ESMFold baseline model using 35M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M()


def esmfold_structure_module_only_35M_270K():
    """
    ESMFold baseline model using 35M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M_270K()


def esmfold_structure_module_only_150M():
    """
    ESMFold baseline model using 150M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M()


def esmfold_structure_module_only_150M_270K():
    """
    ESMFold baseline model using 150M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M_270K()


def esmfold_structure_module_only_650M():
    """
    ESMFold baseline model using 650M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M()


def esmfold_structure_module_only_650M_270K():
    """
    ESMFold baseline model using 650M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M_270K()


def esmfold_structure_module_only_3B():
    """
    ESMFold baseline model using 3B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B()


def esmfold_structure_module_only_3B_270K():
    """
    ESMFold baseline model using 3B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B_270K()


def esmfold_structure_module_only_15B():
    """
    ESMFold baseline model using 15B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    The 15B parameter ESM-2 was not trained out to 500K updates
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).
    """
    import esm.esmfold.v1.pretrained
    return esm.esmfold.v1.pretrained.esmfold_structure_module_only_15B()
esm/rotary_embedding.py
ADDED
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, : x.shape[-2], :]
    sin = sin[:, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, :, :]
            self._sin_cached = emb.sin()[None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
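A small sketch (an assumption for illustration, not part of the upload) of applying RotaryEmbedding to per-head query/key tensors shaped like the ones MultiheadAttention passes in:

# Hypothetical standalone check: the rotation leaves tensor shapes unchanged.
import torch
from esm.rotary_embedding import RotaryEmbedding

rot = RotaryEmbedding(dim=16)
q = torch.randn(8, 10, 16)  # (bsz * num_heads, seq_len, head_dim)
k = torch.randn(8, 10, 16)
q_rot, k_rot = rot(q, k)
print(q_rot.shape, k_rot.shape)  # torch.Size([8, 10, 16]) twice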
esm/version.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

version = "2.0.1"