# P2DFlow / data/repr.py
# Holmes
# test
# ca7299e
import torch
from opt_einsum import contract as einsum
import esm
from data.residue_constants import order2restype_with_mask
def get_pre_repr(seqs, model, alphabet, batch_converter, device="cuda:0", repr_layer=33):
    """Embed ONE integer-encoded protein sequence with an ESM model.

    The residue indices in ``seqs`` are decoded to a residue string via
    ``order2restype_with_mask``, tokenized with ``batch_converter`` as a
    single-entry batch, and passed once through ``model`` without gradients.

    Args:
        seqs: A single sequence of residue indices (despite the plural name).
            Each element must be convertible with ``int()`` and be a valid key
            of ``order2restype_with_mask``.
        model: ESM model, called as
            ``model(tokens, repr_layers=[repr_layer], need_head_weights=True)``.
            It is assumed to already live on ``device``.
        alphabet: No longer used by this function. It is kept only so
            existing callers do not break.
        batch_converter: Converter mapping ``[(label, seq_str)]`` to
            ``(labels, strs, tokens)``.
        device: Device the token tensor is moved to before the forward pass.
        repr_layer: Layer whose embeddings and attention maps are returned.
            Defaults to 33. NOTE(review): this is presumably the final layer of
            a 33-layer ESM-2 model such as esm2_t33_650M; confirm it against
            the loaded model.

    Returns:
        Tuple ``(node_repr, pair_repr)``:
            node_repr: Layer-``repr_layer`` token embeddings with the first and
                last token dropped, shape ``(1, N, D)``.
            pair_repr: Layer-``repr_layer`` attention maps over the same
                tokens, permuted so heads come last, shape ``(1, N, N, H)``.
                This assumes ESM's ``(batch, layer, head, T, T)`` attention
                layout.
    """
    seq_string = "".join(order2restype_with_mask[int(i)] for i in seqs)
    _, _, batch_tokens = batch_converter([("protein_0", seq_string)])

    # The forward pass runs on `device`. need_head_weights=True returns the
    # stacked per-layer attention maps without also running the unused
    # contact-prediction head, which return_contacts=True would trigger.
    with torch.no_grad():
        results = model(
            batch_tokens.to(device),
            repr_layers=[repr_layer],
            need_head_weights=True,
        )

    # Drop the first and last token (presumably BOS/EOS added by the batch
    # converter) so that positions line up with residues in `seqs`.
    node_repr = results["representations"][repr_layer][:, 1:-1, :]
    # In ESM, representation layers are numbered from 1 (0 = embeddings)
    # while the stacked attention tensor is 0-indexed over layers, hence
    # `repr_layer - 1`. This matches the original 33 / 33-1 pairing.
    pair_repr = results["attentions"][:, repr_layer - 1, :, 1:-1, 1:-1].permute(0, 2, 3, 1)
    return node_repr, pair_repr