# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Any, Dict, List, Optional, Tuple, NamedTuple

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from scipy.spatial import transform

from esm.data import Alphabet

from .features import DihedralFeatures
from .gvp_encoder import GVPEncoder
from .gvp_utils import unflatten_graph
from .gvp_transformer_encoder import GVPTransformerEncoder
from .transformer_decoder import TransformerDecoder
from .util import rotate, CoordBatchConverter


class GVPTransformerModel(nn.Module):
"""
GVP-Transformer inverse folding model.
Architecture: Geometric GVP-GNN as initial layers, followed by
sequence-to-sequence Transformer encoder and decoder.
"""
    def __init__(self, args, alphabet):
        super().__init__()
        encoder_embed_tokens = self.build_embedding(
            args, alphabet, args.encoder_embed_dim,
        )
        decoder_embed_tokens = self.build_embedding(
            args, alphabet, args.decoder_embed_dim,
        )
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
        return encoder

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
        )
        return decoder

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.padding_idx
        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.constant_(emb.weight[padding_idx], 0)
        return emb

    def forward(
        self,
        coords,
        padding_mask,
        confidence,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
    ):
        """
        Run the GVP-Transformer encoder on the input backbone coordinates,
        then decode with teacher forcing on ``prev_output_tokens``.
        Returns the decoder logits and the decoder's extra outputs.
        """
        encoder_out = self.encoder(coords, padding_mask, confidence,
                                   return_all_hiddens=return_all_hiddens)
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra

    def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
"""
Samples sequences based on multinomial sampling (no beam search).
Args:
coords: L x 3 x 3 list representing one backbone
partial_seq: Optional, partial sequence with mask tokens if part of
the sequence is known
temperature: sampling temperature, use low temperature for higher
sequence recovery and high temperature for higher diversity
confidence: optional length L list of confidence scores for coordinates
"""
        L = len(coords)
        # Convert to batch format
        batch_converter = CoordBatchConverter(self.decoder.dictionary)
        batch_coords, confidence, _, _, padding_mask = (
            batch_converter([(coords, confidence, None)], device=device)
        )

        # Start with the '<cath>' prepend token; every other position starts
        # out as '<mask>' and is filled in below
        mask_idx = self.decoder.dictionary.get_idx('<mask>')
        sampled_tokens = torch.full((1, 1 + L), mask_idx, dtype=int)
        sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
        if partial_seq is not None:
            for i, c in enumerate(partial_seq):
                sampled_tokens[0, i + 1] = self.decoder.dictionary.get_idx(c)

        # Save incremental states for faster sampling
        incremental_state = dict()

        # Run encoder only once
        encoder_out = self.encoder(batch_coords, padding_mask, confidence)

        # Make sure all tensors are on the same device if a GPU is present
        if device:
            sampled_tokens = sampled_tokens.to(device)

        # Decode one token at a time
        for i in range(1, L + 1):
            logits, _ = self.decoder(
                sampled_tokens[:, :i],
                encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits[0].transpose(0, 1)
            logits /= temperature
            probs = F.softmax(logits, dim=-1)
            # Only sample positions that are still '<mask>', so residues fixed
            # through partial_seq are kept as given
            if sampled_tokens[0, i] == mask_idx:
                sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
        sampled_seq = sampled_tokens[0, 1:]

        # Convert back to string via lookup
        return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq])