import pickle
import random

import esm
import numpy as np
import torch

from utils import tokenize_sequences, rna2vec


class SimAnnealer:

    def __init__(self, temperature, model, steps, target, length):
        self.init_temp = temperature   # starting temperature, kept for the cooling schedule
        self.temp = temperature
        self.device = 'cpu'
        self.model = model.to(self.device)
        self.length = length
        self.initial_state = self.random_state_generator()
        self.alpha = 0.95              # geometric cooling rate
        self.steps = steps
        self.target = target           # target protein sequence
        self.n_prot_vocabs = 1 + 713 + 1
        self.n_prot_target_vocabs = 1 + 584
        self.prot_max_len = 867

        # Keep only the protein 3-mers that occur more often than average.
        with open('./data/protein_word_freq_3.pickle', 'rb') as fr:
            words = pickle.load(fr)
        words = words[words["freq"] > words.freq.mean()].seq.values
        self.prot_words = {word: i + 1 for i, word in enumerate(words)}

    def simulation(self):
        state = self.initial_state
        tokenized_target = self.prot_tokenizer().to(self.device)
        self.model.eval()
        out_p = None
        for i in range(self.steps):
            if i % 10 == 0:
                print(f'Running simulation at step {i}')
            tokenized_state = self.apta_tokenizer(state).to(self.device)
            self.temp = self.temp_scheduler(i)
            neighbor = self.mutate(state)
            tokenized_neighbor = self.apta_tokenizer(neighbor).to(self.device)
            with torch.no_grad():
                d1 = self.model(tokenized_state, tokenized_target)
                d2 = self.model(tokenized_neighbor, tokenized_target)
            if d2 > d1:
                # The neighbor scores higher: always accept it.
                state = neighbor
                out_p = d2
            else:
                # The neighbor scores lower: accept it with the Metropolis
                # probability exp((d2 - d1) / T), which shrinks as T cools.
                p = torch.exp((d2 - d1) / self.temp)
                if torch.rand(size=(1,)) < torch.squeeze(p):
                    state = neighbor
                    out_p = d2
                else:
                    out_p = d1
        return state, out_p

    def prot_tokenizer(self):
        # Only the ESM-2 alphabet is needed here; the pretrained weights are discarded.
        _, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        bc = esm_alphabet.get_batch_converter()
        _, _, prot_tokens = bc([(1, self.target)])
        prot_tokenized = prot_tokens.to(dtype=torch.int64)

        # Right-pad the tokens to the fixed protein length expected by the model.
        prot_ex = torch.ones((prot_tokenized.shape[0], 1678), dtype=torch.int64) * esm_alphabet.padding_idx
        prot_ex[:, :prot_tokenized.shape[1]] = prot_tokenized
        return prot_ex.to(self.device)

    def apta_tokenizer(self, aptamer):
        # rna2vec operates on an array of sequences, hence the single-element batch.
        return torch.tensor(rna2vec(np.array([aptamer])), dtype=torch.int64)

    def temp_scheduler(self, t):
        # Geometric cooling from the initial temperature: T(t) = T0 * alpha^t.
        return self.init_temp * self.alpha ** t

    def mutate(self, state):
        # Replace one randomly chosen base with a different base.
        base_ind = np.random.choice(len(state))
        base = state[base_ind]
        cands = ['U', 'A', 'C', 'G']
        choice = base
        while choice == base:
            choice = cands[np.random.choice(len(cands))]
        state = list(state)
        state[base_ind] = choice
        return "".join(state)

    def random_state_generator(self):
        # Random RNA sequence of the requested length, returned as a string.
        cands = ['U', 'A', 'C', 'G']
        return "".join(random.choices(cands, k=self.length))
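

# Minimal usage sketch (an illustration, not part of the original module). It assumes
# a trained aptamer-protein interaction model saved at the hypothetical path
# './models/api_model.pt' that maps (tokenized aptamer, tokenized protein) to a
# scalar interaction score; adjust the path, target sequence, and hyperparameters
# to your own setup.
if __name__ == "__main__":
    model = torch.load('./models/api_model.pt', map_location='cpu')
    annealer = SimAnnealer(
        temperature=1.0,   # initial temperature T0
        model=model,
        steps=100,         # number of annealing steps
        target='MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',  # example target protein sequence
        length=30,         # aptamer length in bases
    )
    best_state, best_score = annealer.simulation()
    print(f'best aptamer: {best_state}, score: {best_score}')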