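# Source path: FAPM_demo/esm/esmfold/v1/tri_self_attn_block.py
# Hosting-page metadata (file-viewer header, not part of the module):
#   user "wenkai", commit 3f0529e (verified), message "Upload 31 files",
#   no virus detected, file size 6.2 kB.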
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from openfold.model.triangular_attention import (
TriangleAttentionEndingNode,
TriangleAttentionStartingNode,
)
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming,
TriangleMultiplicationOutgoing,
)
from torch import nn
from esm.esmfold.v1.misc import (
Attention,
Dropout,
PairToSequence,
ResidueMLP,
SequenceToPair,
)
class TriangularSelfAttentionBlock(nn.Module):
    """Block that refines a per-residue sequence state and a pairwise state.

    The sequence state is updated first: gated self-attention, biased by a
    projection of the pairwise state, followed by an MLP.  The pairwise state
    is then updated from the new sequence state and passed through two
    triangle multiplicative updates, two triangle attention layers and an MLP.

    Args:
        sequence_state_dim: channel size of the sequence state; must be a
            multiple of ``sequence_head_width``.
        pairwise_state_dim: channel size of the pairwise state; must be even
            and a multiple of ``pairwise_head_width``.
        sequence_head_width: per-head width of the sequence attention.
        pairwise_head_width: per-head width of the triangle attention.
        dropout: dropout rate, must be below 0.4.  The row/column dropout on
            the pairwise updates is built with twice this rate.
        **__kwargs: accepted and ignored.
    """

    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        super().__init__()

        # Both channel sizes have to split evenly into heads of the requested width.
        seq_heads, seq_leftover = divmod(sequence_state_dim, sequence_head_width)
        assert seq_leftover == 0
        pair_heads, pair_leftover = divmod(pairwise_state_dim, pairwise_head_width)
        assert pair_leftover == 0
        # SequenceToPair below is given pairwise_state_dim // 2 as its inner size.
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        # NOTE: submodules are created in a fixed order on purpose; the order
        # determines parameter registration and random-init consumption.
        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim,
            pairwise_state_dim // 2,
            pairwise_state_dim,
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, seq_heads)

        self.seq_attention = Attention(
            sequence_state_dim,
            seq_heads,
            sequence_head_width,
            gated=True,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(pairwise_state_dim, pairwise_state_dim)
        self.tri_mul_in = TriangleMultiplicationIncoming(pairwise_state_dim, pairwise_state_dim)

        # Both triangle attention layers share the same positional arguments.
        pair_attn_args = (pairwise_state_dim, pairwise_head_width, pair_heads)
        self.tri_att_start = TriangleAttentionStartingNode(*pair_attn_args, inf=1e9)  # type: ignore
        self.tri_att_end = TriangleAttentionEndingNode(*pair_attn_args, inf=1e9)  # type: ignore

        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        # The row/column dropout below is constructed with dropout * 2.
        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        # NOTE(review): the second argument (2 / 1) presumably selects the axis
        # the dropout mask is shared along -- confirm against Dropout in misc.
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        # Zero-initialise selected linear layers.  Presumably these are the
        # output projections of each update branch, so the block starts out
        # close to an identity mapping -- confirm against the submodule code.
        zeroed_layers = (
            self.tri_mul_in.linear_z,
            self.tri_mul_out.linear_z,
            self.tri_att_start.mha.linear_o,
            self.tri_att_end.mha.linear_o,
            self.sequence_to_pair.o_proj,
            self.seq_attention.o_proj,
            self.mlp_seq.mlp[-2],
            self.mlp_pair.mlp[-2],
        )
        for layer in zeroed_layers:
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)
        # Only the weight of this projection is zeroed (no bias is touched).
        nn.init.zeros_(self.pair_to_sequence.linear.weight)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """Run one update of the sequence and pairwise states.

        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions (optional)
          chunk_size: forwarded unchanged to the two triangle attention layers

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim

        Extra keyword arguments are accepted and ignored.
        """
        # Rank checks come first so the shape unpacking below is well defined.
        assert sequence_state.ndim == 3
        assert pairwise_state.ndim == 4
        if mask is not None:
            assert mask.ndim == 2

        n_batch, n_res, seq_channels = sequence_state.shape
        assert seq_channels == self.sequence_state_dim
        assert pairwise_state.shape[3] == self.pairwise_state_dim
        # Batch size and both length axes of the pair tensor must match the sequence tensor.
        assert tuple(pairwise_state.shape[:3]) == (n_batch, n_res, n_res)

        # ---- Sequence update: pair-biased self attention, then MLP ----
        attn_bias = self.pair_to_sequence(pairwise_state)
        normed_seq = self.layernorm_1(sequence_state)
        attn_out, _ = self.seq_attention(normed_seq, mask=mask, bias=attn_bias)
        sequence_state = self.mlp_seq(sequence_state + self.drop(attn_out))

        # ---- Pairwise update, driven by the refreshed sequence state ----
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Outer product of the residue mask with itself: B x L x L.
        if mask is None:
            tri_mask = None
        else:
            tri_mask = mask[:, :, None] * mask[:, None, :]

        # Triangle multiplicative updates (outgoing, then incoming).
        delta = self.tri_mul_out(pairwise_state, mask=tri_mask)
        pairwise_state = pairwise_state + self.row_drop(delta)
        delta = self.tri_mul_in(pairwise_state, mask=tri_mask)
        pairwise_state = pairwise_state + self.col_drop(delta)

        # Triangle attention (starting node, then ending node).
        delta = self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        pairwise_state = pairwise_state + self.row_drop(delta)
        delta = self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        pairwise_state = pairwise_state + self.col_drop(delta)

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state