Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from openfold.model.triangular_attention import ( | |
TriangleAttentionEndingNode, | |
TriangleAttentionStartingNode, | |
) | |
from openfold.model.triangular_multiplicative_update import ( | |
TriangleMultiplicationIncoming, | |
TriangleMultiplicationOutgoing, | |
) | |
from torch import nn | |
from esm.esmfold.v1.misc import ( | |
Attention, | |
Dropout, | |
PairToSequence, | |
ResidueMLP, | |
SequenceToPair, | |
) | |
class TriangularSelfAttentionBlock(nn.Module):
    """Block that updates a sequence state and a pairwise state together.

    The sequence state goes through gated self-attention whose bias is
    computed from the pairwise state, then through an MLP. The pairwise state
    is then updated from the new sequence state, passed through the two
    triangle multiplication modules and the two triangle attention modules,
    and finally through its own MLP.
    """

    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        """Build the submodules and zero selected projection layers.

        Args:
            sequence_state_dim: Channel size of the sequence state; must be a
                multiple of ``sequence_head_width``.
            pairwise_state_dim: Channel size of the pairwise state; must be
                even and a multiple of ``pairwise_head_width``.
            sequence_head_width: Per-head width for the sequence attention.
            pairwise_head_width: Per-head width for the triangle attention.
            dropout: Dropout rate, asserted to be below 0.4. The row/column
                dropout modules are built with twice this rate.
            **__kwargs: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        # Head counts are derived from the widths; both must divide exactly.
        sequence_num_heads, sequence_leftover = divmod(sequence_state_dim, sequence_head_width)
        assert sequence_leftover == 0
        pairwise_num_heads, pairwise_leftover = divmod(pairwise_state_dim, pairwise_head_width)
        assert pairwise_leftover == 0
        assert sequence_num_heads * sequence_head_width == sequence_state_dim
        assert pairwise_num_heads * pairwise_head_width == pairwise_state_dim
        # SequenceToPair below is built with pairwise_state_dim // 2.
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        # NOTE: submodules are created in a fixed order; it determines the
        # parameter registration order of this module.
        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_att_start = TriangleAttentionStartingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore
        self.tri_att_end = TriangleAttentionEndingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore

        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        # Zero both weight and bias of these layers.
        # NOTE(review): presumably so that every update branch starts out
        # contributing nothing -- confirm against the submodule definitions.
        layers_zeroed_with_bias = (
            self.tri_mul_in.linear_z,
            self.tri_mul_out.linear_z,
            self.tri_att_start.mha.linear_o,
            self.tri_att_end.mha.linear_o,
            self.sequence_to_pair.o_proj,
            self.seq_attention.o_proj,
            self.mlp_seq.mlp[-2],
            self.mlp_pair.mlp[-2],
        )
        for layer in layers_zeroed_with_bias:
            torch.nn.init.zeros_(layer.weight)
            torch.nn.init.zeros_(layer.bias)
        # Only the weight of this layer is zeroed.
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """Run one update of the sequence and pairwise states.

        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions
          chunk_size: forwarded to the two triangle attention modules

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        # Rank checks.
        assert sequence_state.ndim == 3
        assert pairwise_state.ndim == 4
        if mask is not None:
            assert mask.ndim == 2

        # Channel sizes must match the configuration; batch and length must
        # agree between the two states.
        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pair_shape = pairwise_state.shape
        pairwise_state_dim = pair_shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pair_shape[0]
        assert seq_dim == pair_shape[1]
        assert seq_dim == pair_shape[2]

        # Update sequence state: attention biased by the pairwise state,
        # added back to the input, then the sequence MLP.
        attention_bias = self.pair_to_sequence(pairwise_state)
        normed_sequence = self.layernorm_1(sequence_state)
        attended, _ = self.seq_attention(normed_sequence, mask=mask, bias=attention_bias)
        sequence_state = sequence_state + self.drop(attended)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state from the refreshed sequence state.
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Pairwise mask: product of the mask broadcast along rows and columns.
        if mask is None:
            tri_mask = None
        else:
            tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1)

        # Four pairwise updates applied in sequence, each added back after
        # row or column dropout. Only the attention modules take chunk_size.
        chunked = {"chunk_size": chunk_size}
        pair_updates = (
            (self.tri_mul_out, self.row_drop, {}),
            (self.tri_mul_in, self.col_drop, {}),
            (self.tri_att_start, self.row_drop, chunked),
            (self.tri_att_end, self.col_drop, chunked),
        )
        for update, dropout_layer, extra_kwargs in pair_updates:
            delta = update(pairwise_state, mask=tri_mask, **extra_kwargs)
            pairwise_state = pairwise_state + dropout_layer(delta)

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state