import torch
from openfold.model.triangular_attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)
from torch import nn

from esm.esmfold.v1.misc import (
    Attention,
    Dropout,
    PairToSequence,
    ResidueMLP,
    SequenceToPair,
)


class TriangularSelfAttentionBlock(nn.Module):
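    """Pair-biased sequence self-attention combined with AlphaFold2-style
    triangular multiplicative and triangular attention updates of the
    pairwise state (the triangle modules are taken from OpenFold)."""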
    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        super().__init__()

        assert sequence_state_dim % sequence_head_width == 0
        assert pairwise_state_dim % pairwise_head_width == 0
        sequence_num_heads = sequence_state_dim // sequence_head_width
        pairwise_num_heads = pairwise_state_dim // pairwise_head_width
        assert sequence_state_dim == sequence_num_heads * sequence_head_width
        assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

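        # Projections between the two representations: sequence -> pair features,
        # and pair -> per-head attention biases for the sequence attention.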
        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)

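        # Gated self-attention over the sequence state, plus the four triangular
        # updates (outgoing/incoming multiplication, starting/ending-node attention)
        # over the pairwise state.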
        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_att_start = TriangleAttentionStartingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )

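        # Residual transition MLPs for the sequence and pair states.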
        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

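        # drop is applied to the sequence-attention branch; row_drop / col_drop
        # apply dropout shared along one dimension of the pair tensor, at twice
        # the base rate.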
        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

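        # Zero-initialize the output projections of every residual branch so that
        # a freshly constructed block starts out (approximately) as the identity.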
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)

        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        assert len(sequence_state.shape) == 3
        assert len(pairwise_state.shape) == 4
        if mask is not None:
            assert len(mask.shape) == 2

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pairwise_state.shape[0]
        assert seq_dim == pairwise_state.shape[1]
        assert seq_dim == pairwise_state.shape[2]

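        # Update sequence state: turn the pair representation into a per-head
        # attention bias.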
        bias = self.pair_to_sequence(pairwise_state)

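        # Pair-biased gated self-attention over residues, then the sequence MLP.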
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

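        # Update pairwise state from the refreshed sequence state.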
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

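        # Triangular multiplicative and triangular attention updates, each with
        # row- or column-shared dropout on the residual branch.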
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_mul_out(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_mul_in(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

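        # MLP over the pair state.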
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
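

# Minimal usage sketch (illustrative, not part of the original module). The
# dimensions, tensors, and mask below are arbitrary assumptions chosen only to
# satisfy the constructor's divisibility checks; running it requires the same
# `openfold` and `esm` packages that the imports above depend on.
if __name__ == "__main__":
    block = TriangularSelfAttentionBlock(
        sequence_state_dim=64,
        pairwise_state_dim=32,
        sequence_head_width=16,
        pairwise_head_width=16,
        dropout=0.1,
    ).eval()

    B, L = 2, 16
    sequence_state = torch.randn(B, L, 64)
    pairwise_state = torch.randn(B, L, L, 32)
    mask = torch.ones(B, L)  # all positions valid

    with torch.no_grad():
        s_out, z_out = block(sequence_state, pairwise_state, mask=mask)

    assert s_out.shape == (B, L, 64)
    assert z_out.shape == (B, L, L, 32)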