# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# import argparse
# from typing import Any, Dict, List, Optional, Tuple, NamedTuple
import torch
from torch import nn
# from torch import Tensor
import torch.nn.functional as F
# from scipy.spatial import transform
# from esm.data import Alphabet
# from .features import DihedralFeatures
# from .gvp_encoder import GVPEncoder
# from .gvp_utils import unflatten_graph
from .gvp_transformer_encoder import GVPTransformerEncoder
from .transformer_decoder import TransformerDecoder
from .util import rotate, CoordBatchConverter


class GVPTransformerModel(nn.Module):
    """
    GVP-Transformer inverse folding model.

    Architecture: Geometric GVP-GNN as initial layers, followed by
    sequence-to-sequence Transformer encoder and decoder.
    """

    def __init__(self, args, alphabet):
        super().__init__()
        encoder_embed_tokens = self.build_embedding(
            args, alphabet, args.encoder_embed_dim,
        )
        decoder_embed_tokens = self.build_embedding(
            args, alphabet, args.decoder_embed_dim,
        )
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
        return encoder

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
        )
        return decoder

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.padding_idx
        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.constant_(emb.weight[padding_idx], 0)
        return emb

    def forward(
        self,
        coords,
        padding_mask,
        confidence,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
    ):
        encoder_out = self.encoder(coords, padding_mask, confidence,
            return_all_hiddens=return_all_hiddens)
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra

    def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
        """
        Samples sequences based on multinomial sampling (no beam search).

        Args:
            coords: L x 3 x 3 list representing one backbone
            partial_seq: Optional, partial sequence with mask tokens if part
                of the sequence is known
            temperature: sampling temperature, use low temperature for higher
                sequence recovery and high temperature for higher diversity
            confidence: optional length L list of confidence scores for
                coordinates
        """
        L = len(coords)
        # Convert to batch format
        batch_converter = CoordBatchConverter(self.decoder.dictionary)
        batch_coords, confidence, _, _, padding_mask = (
            batch_converter([(coords, confidence, None)], device=device)
        )

        # Start with prepend token; all other positions begin as <mask>
        mask_idx = self.decoder.dictionary.get_idx('<mask>')
        sampled_tokens = torch.full((1, 1+L), mask_idx, dtype=int)
        sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
        if partial_seq is not None:
            # Copy known residues from the partial sequence; positions left as
            # <mask> are filled in during decoding
            for i, c in enumerate(partial_seq):
                sampled_tokens[0, i+1] = self.decoder.dictionary.get_idx(c)

        # Save incremental states for faster sampling
        incremental_state = dict()

        # Run encoder only once
        encoder_out = self.encoder(batch_coords, padding_mask, confidence)

        # Make sure all tensors are on the same device if a GPU is present
        if device:
            sampled_tokens = sampled_tokens.to(device)

        # Decode one token at a time
        for i in range(1, L+1):
            logits, _ = self.decoder(
                sampled_tokens[:, :i],
                encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits[0].transpose(0, 1)
            logits /= temperature
            probs = F.softmax(logits, dim=-1)
            if sampled_tokens[0, i] == mask_idx:
                sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
        sampled_seq = sampled_tokens[0, 1:]

        # Convert back to string via lookup
        return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq]), encoder_out
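

# A minimal usage sketch, not part of the original module: it assumes the esm
# package's pretrained loader (esm.pretrained.esm_if1_gvp4_t16_142M_UR50) and
# coordinate helper (esm.inverse_folding.util.load_coords) are available, and
# that "example.pdb" / chain "A" are hypothetical placeholders supplied by the
# user. Run as a module (python -m ...) so the relative imports above resolve.
if __name__ == "__main__":
    import esm
    import esm.inverse_folding

    model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    model = model.eval()

    # Backbone coordinates (N, CA, C per residue) for one chain of a structure
    coords, native_seq = esm.inverse_folding.util.load_coords("example.pdb", "A")

    # sample() in this module returns both the sampled sequence and encoder_out
    sampled_seq, encoder_out = model.sample(coords, temperature=1.0)
    print("native :", native_seq)
    print("sampled:", sampled_seq)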