# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from openfold.model.triangular_attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)
from torch import nn

from esm.esmfold.v1.misc import (
    Attention,
    Dropout,
    PairToSequence,
    ResidueMLP,
    SequenceToPair,
)


class TriangularSelfAttentionBlock(nn.Module):
    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        super().__init__()

        assert sequence_state_dim % sequence_head_width == 0
        assert pairwise_state_dim % pairwise_head_width == 0
        sequence_num_heads = sequence_state_dim // sequence_head_width
        pairwise_num_heads = pairwise_state_dim // pairwise_head_width
        assert sequence_state_dim == sequence_num_heads * sequence_head_width
        assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_att_start = TriangleAttentionStartingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore
        self.tri_att_end = TriangleAttentionEndingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore

        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)

        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        assert len(sequence_state.shape) == 3
        assert len(pairwise_state.shape) == 4
        if mask is not None:
            assert len(mask.shape) == 2

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pairwise_state.shape[0]
        assert seq_dim == pairwise_state.shape[1]
        assert seq_dim == pairwise_state.shape[2]

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_mul_out(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_mul_in(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
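

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving a single TriangularSelfAttentionBlock with random
# tensors, matching the shapes documented in forward(). The state dims, head
# widths, dropout, batch size, and sequence length below are assumed toy values,
# not canonical ESMFold hyperparameters; any values satisfying the divisibility
# and parity asserts in __init__ will do. The optional `mask` argument (B x L,
# valid positions) is omitted here for simplicity.
if __name__ == "__main__":
    block = TriangularSelfAttentionBlock(
        sequence_state_dim=64,    # divisible by sequence_head_width
        pairwise_state_dim=32,    # even and divisible by pairwise_head_width
        sequence_head_width=16,   # -> 4 sequence attention heads
        pairwise_head_width=16,   # -> 2 pairwise attention heads
        dropout=0.1,              # must be < 0.4 per the assert in __init__
    )
    B, L = 2, 16
    s = torch.randn(B, L, 64)       # sequence state: B x L x sequence_state_dim
    z = torch.randn(B, L, L, 32)    # pairwise state: B x L x L x pairwise_state_dim
    s, z = block(s, z)
    # Shapes are preserved: torch.Size([2, 16, 64]) torch.Size([2, 16, 16, 32])
    print(s.shape, z.shape)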