import argparse
from typing import Any, Dict, List, Optional, Tuple, NamedTuple

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from scipy.spatial import transform

from esm.data import Alphabet

from .features import DihedralFeatures
from .gvp_encoder import GVPEncoder
from .gvp_utils import unflatten_graph
from .gvp_transformer_encoder import GVPTransformerEncoder
from .transformer_decoder import TransformerDecoder
from .util import rotate, CoordBatchConverter


class GVPTransformerModel(nn.Module):
    """
    GVP-Transformer inverse folding model.

    Architecture: Geometric GVP-GNN as initial layers, followed by
    sequence-to-sequence Transformer encoder and decoder.
    """

    def __init__(self, args, alphabet):
        super().__init__()
        encoder_embed_tokens = self.build_embedding(
            args, alphabet, args.encoder_embed_dim,
        )
        decoder_embed_tokens = self.build_embedding(
            args, alphabet, args.decoder_embed_dim,
        )
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
        return encoder

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
        )
        return decoder

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.padding_idx
        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        # Scaled normal init (std = embed_dim ** -0.5); zero the padding row
        nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.constant_(emb.weight[padding_idx], 0)
        return emb

    def forward(
        self,
        coords,
        padding_mask,
        confidence,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
    ):
        # Encode the backbone once, then decode with teacher forcing on
        # prev_output_tokens (the target tokens shifted right by one).
        encoder_out = self.encoder(coords, padding_mask, confidence,
            return_all_hiddens=return_all_hiddens)
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra
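
    # A minimal sketch of driving forward() to score a native sequence against
    # a backbone (hypothetical usage; assumes `model` is a loaded
    # GVPTransformerModel, `alphabet` its Alphabet, and `coords`/`seq` one
    # backbone with its amino acid sequence):
    #
    #     batch_converter = CoordBatchConverter(alphabet)
    #     coords_t, confidence, _, tokens, padding_mask = batch_converter(
    #         [(coords, None, seq)])
    #     prev_output_tokens = tokens[:, :-1]
    #     target = tokens[:, 1:]
    #     logits, _ = model.forward(coords_t, padding_mask, confidence,
    #                               prev_output_tokens)
    #     loss = F.cross_entropy(logits, target, reduction='none')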

    def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
        """
        Samples sequences based on multinomial sampling (no beam search).

        Args:
            coords: L x 3 x 3 list representing one backbone
            partial_seq: Optional, partial sequence with mask tokens if part of
                the sequence is known
            temperature: sampling temperature, use low temperature for higher
                sequence recovery and high temperature for higher diversity
            confidence: optional length L list of confidence scores for coordinates
            device: optional torch device on which to run sampling
        """
        L = len(coords)
        # Convert the single backbone to batch format
        batch_converter = CoordBatchConverter(self.decoder.dictionary)
        batch_coords, confidence, _, _, padding_mask = (
            batch_converter([(coords, confidence, None)], device=device)
        )

        # Start with the prepend token; all other positions begin as <mask>
        mask_idx = self.decoder.dictionary.get_idx('<mask>')
        sampled_tokens = torch.full((1, 1+L), mask_idx, dtype=int)
        sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
        if partial_seq is not None:
            for i, c in enumerate(partial_seq):
                sampled_tokens[0, i+1] = self.decoder.dictionary.get_idx(c)

        # Save incremental states for faster sampling
        incremental_state = dict()

        # Run the encoder only once
        encoder_out = self.encoder(batch_coords, padding_mask, confidence)

        # Make sure all tensors are on the same device if a GPU is present
        if device:
            sampled_tokens = sampled_tokens.to(device)

        # Autoregressively decode one token at a time, keeping any token that
        # was fixed via partial_seq
        for i in range(1, L+1):
            logits, _ = self.decoder(
                sampled_tokens[:, :i],
                encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits[0].transpose(0, 1)
            logits /= temperature
            probs = F.softmax(logits, dim=-1)
            if sampled_tokens[0, i] == mask_idx:
                sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
        sampled_seq = sampled_tokens[0, 1:]

        # Convert back to a string via the token lookup
        return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq])
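
# Example: sampling with a partially fixed sequence (a hypothetical sketch;
# assumes `model` is a loaded GVPTransformerModel in eval mode and `coords` is
# an L x 3 x 3 backbone; positions left as '<mask>' are designed, all others
# are kept fixed by the masking logic in sample() above):
#
#     partial_seq = ['<mask>'] * len(coords)
#     partial_seq[:5] = list('MKTAY')  # pin the first five residues
#     with torch.no_grad():
#         sampled_seq = model.sample(coords, partial_seq=partial_seq,
#                                    temperature=1.0)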