import timeit

import numpy as np
import torch
from transformers import AutoTokenizer
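

# Node: one vertex of the MCTS tree. Each node carries a single action token from
# `letters` ("A_" prepends the nucleotide, "_A" appends it), so the path from the
# root to a node encodes a partial aptamer sequence.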
class Node:

    def __init__(self, letter="", parent=None, root=False, last=False, depth=0, states=8):
        self.exploitation_score = 0
        self.visits = 1
        self.letter = letter
        self.parent = parent
        self.states = states
        self.children = np.array([None for _ in range(self.states)])
        self.children_stat = np.zeros(self.states, dtype=bool)
        self.root = root
        self.last = last
        self.depth = depth
        self.letters = ["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
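
    # Return the already-generated child at index `child`.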
    def next_node(self, child=0):
        assert self.children_stat[child] == True, "No child generated at this index."
        return self.children[child]
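
    # Step back to the parent node, also returning this node's action index.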
    def back_parent(self):
        return self.parent, self.letters.index(self.letter)
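
    # Create the child node for action index `child` and mark that slot as filled.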
    def generate_child(self, child=0, last=False):
        assert self.children_stat[child] == False, "Child already generated at this index."
        self.children[child] = Node(letter=self.letters[child], parent=self, last=last, depth=self.depth + 1, states=self.states)
        self.children_stat[child] = True
        return self.children[child]
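
    # Propagate a simulation score up to the root, incrementing visit counts along
    # the way; the root only records the extra visit.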
    def backpropagation(self, score=0):
        self.visits += 1
        if self.root == True:
            return self.exploitation_score
        else:
            self.exploitation_score += score
            return self.parent.backpropagation(score=score)
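
    # Upper Confidence bound for Trees: mean score plus an exploration term driven
    # by how often the parent has been visited.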
    def UCT(self):
        return (self.exploitation_score / self.visits) + np.sqrt(np.log(self.parent.visits) / (2 * self.visits))
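

# Monte Carlo tree search over aptamer sequences. Each tree level picks one of
# eight actions (prepend or append A/C/G/T); rollouts are scored by the given
# aptamer-protein interaction classifier against the encoded target protein.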
class MCTS:

    def __init__(self, target_encoded, depth=20, iteration=1000, states=8, target_protein="", device='cpu', esm_alphabet=None):
        self.states = states
        self.root = Node(letter="", parent=None, root=True, last=False, states=self.states)
        self.depth = depth
        self.iteration = iteration
        self.target_protein = target_protein
        self.device = device
        self.encoded_targetprotein = target_encoded
        self.base = ""
        self.candidate = ""
        self.letters = ["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
        self.esm_alphabet = esm_alphabet
        self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)
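
    # Run MCTS rounds until the candidate reaches `depth` nucleotides: each round
    # performs `iteration` selections, commits the best subsequence found so far
    # to `self.base`, then restarts the search from a fresh root.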
    def make_candidate(self, classifier):
        now = self.root
        n = 0
        start_time = timeit.default_timer()

        while len(self.base) < self.depth * 2:
            n += 1
            print(n, "round start!!!")
            for _ in range(self.iteration):
                now = self.select(classifier, now=now)

            terminate_time = timeit.default_timer()
            print(n, "round done, elapsed:", round(terminate_time - start_time, 1), "s")

            base = self.find_best_subsequence()
            self.base = base

            self.root = Node(letter="", parent=None, root=True, last=False, states=self.states, depth=len(self.base) // 2)
            now = self.root

        self.candidate = self.base
        return self.candidate
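
    # Selection step: if every child exists, follow the highest-UCT child; otherwise
    # pick a random action, expanding (and simulating) it if it is unexplored, and
    # restart the next selection from the root.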
    def select(self, classifier, now=None):
        if now.depth == self.depth:
            return self.root

        next_node = 0
        if np.sum(now.children_stat) == self.states:
            best = 0
            for i in range(self.states):
                if best < now.children[i].UCT():
                    next_node = i
                    best = now.children[i].UCT()
        else:
            next_node = np.random.randint(0, self.states)
            if now.children_stat[next_node] == False:
                next_node = self.expand(classifier, child=next_node, now=now)
                return self.root

        return now.next_node(child=next_node)
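
    # Expansion step: generate the chosen child, run one rollout from it and
    # backpropagate the resulting score; returns the expanded child's index.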
    def expand(self, classifier, child=None, now=None):
        last = False
        if now.depth == (self.depth - 1):
            last = True

        expanded_node = now.generate_child(child=child, last=last)

        score = self.simulate(classifier, target=expanded_node)
        expanded_node.backpropagation(score=score)

        return child
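
    # Rollout: rebuild the partial sequence along the path to `target`, pad it with
    # random actions up to the full `depth * 2` encoding length, then tokenize the
    # reconstructed aptamer and score it against the target protein.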
    def simulate(self, classifier, target=None):
        now = target
        sim_seq = ""

        while now.root != True:
            sim_seq = now.letter + sim_seq
            now = now.parent

        sim_seq = self.base + sim_seq

        while len(sim_seq) < self.depth * 2:
            r = np.random.randint(0, self.states)
            sim_seq += self.letters[r]

        sim_seq = self.reconstruct(sim_seq)

        classifier.eval().to(self.device)
        with torch.no_grad():
            apta_toks = self.nt_tokenizer.batch_encode_plus([sim_seq], return_tensors='pt', padding='max_length', max_length=275)['input_ids']
            apta_attention_mask = apta_toks != self.nt_tokenizer.pad_token_id
            prot_attention_mask = self.encoded_targetprotein != self.esm_alphabet.padding_idx
            score, _, _, _ = classifier(apta_toks.to(self.device), self.encoded_targetprotein.to(self.device), apta_attention_mask.to(self.device), prot_attention_mask.to(self.device))

        return score
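
    # Return the current candidate decoded into a plain nucleotide string.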
    def get_candidate(self):
        return self.reconstruct(self.candidate)
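
    # Greedily walk the tree from the root, always following the generated child with
    # the highest UCT, and append each step's letter to the committed base until a
    # leaf (a node with no generated children) is reached.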
    def find_best_subsequence(self):
        now = self.root
        base = self.base

        for _ in range(self.depth - len(base) // 2):
            best = 0
            next_node = 0
            for j in range(self.states):
                if now.children_stat[j] == True:
                    if best < now.children[j].UCT():
                        next_node = j
                        best = now.children[j].UCT()

            now = now.next_node(child=next_node)
            base += now.letter

            if np.sum(now.children_stat) == 0:
                break

        return base
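
    # Decode the underscore notation: an "X_" token prepends X to the sequence, a
    # "_X" token appends X, e.g. reconstruct("G__TA_") == "AGT".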
    def reconstruct(self, seq=""):
        r_seq = ""
        for i in range(0, len(seq), 2):
            if seq[i] == '_':
                r_seq = r_seq + seq[i + 1]
            else:
                r_seq = seq[i] + r_seq
        return r_seq
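
    # Discard the committed base and candidate and start over with a fresh tree.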
    def reset(self):
        self.base = ""
        self.candidate = ""
        self.root = Node(letter="", parent=None, root=True, last=False, states=self.states)
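

# Minimal smoke-test sketch (not part of the original module). The real pipeline
# passes an aptamer-protein interaction classifier and an ESM alphabet; both are
# replaced here by hypothetical stand-ins (_DummyClassifier, _DummyAlphabet) with
# the same call shape, just to exercise the search loop on CPU with a short sequence.
if __name__ == "__main__":
    class _DummyAlphabet:
        padding_idx = 1

    class _DummyClassifier(torch.nn.Module):
        def forward(self, apta, prot, apta_mask, prot_mask):
            # random "binding score" plus three placeholder outputs, matching the
            # (score, _, _, _) unpacking used in simulate()
            return torch.rand(1), None, None, None

    target = torch.randint(2, 30, (1, 50))  # fake encoded target protein tokens
    mcts = MCTS(target, depth=5, iteration=30, device='cpu', esm_alphabet=_DummyAlphabet())
    print("encoded candidate:", mcts.make_candidate(_DummyClassifier()))
    print("aptamer:", mcts.get_candidate())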