File size: 7,898 Bytes
e4cb743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import numpy as np
import timeit
import torch
from transformers import AutoTokenizer

#Node
class Node:
    """A single node of the search tree.

    Each node stands for one two-character "letter" (for example ``"A_"`` or
    ``"_T"``) and owns up to ``states`` child slots, one per entry of
    ``self.letters``.

    Args:
        letter: Two-character letter this node represents ("" for the root).
        parent: Parent ``Node``, or ``None`` for the root.
        root: Whether this node is the root of the tree.
        last: Whether this node sits at the maximum depth.
        depth: Depth of this node in the tree.
        states: Number of child slots.
    """

    def __init__(self, letter="", parent=None, root=False, last=False, depth=0, states=8):
        self.exploitation_score = 0  # Running sum of back-propagated scores
        self.visits = 1  # Visit counter; starts at 1 so UCT never divides by zero
        self.letter = letter  # This node's letter
        self.parent = parent  # This node's parent node
        self.states = states  # How many child slots this node has
        self.children = np.array([None] * self.states)  # Child nodes, indexed by slot
        self.children_stat = np.zeros(self.states, dtype=bool)  # Which slots are expanded
        self.root = root  # Is this the root node?
        self.last = last  # Is this a node at maximum depth?
        self.depth = depth  # Depth of this node
        # Slot index -> letter. Child ``i`` always carries ``self.letters[i]``.
        self.letters = ["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]

    def next_node(self, child=0):
        """Return the already-expanded child stored in slot ``child``.

        Raises:
            AssertionError: If that slot has not been expanded yet.
        """
        assert self.children_stat[child], "No child in here."

        return self.children[child]

    def back_parent(self):
        """Return ``(parent, slot)`` for stepping back up the tree.

        ``slot`` is the position of this node's letter in ``self.letters``,
        i.e. the slot this node occupies in its parent. It is ``None`` when
        the letter is not in the table (the root, whose letter is "").
        """
        # The lookup table is the node's own ``letters`` list; the previously
        # referenced ``letters_map`` is not defined anywhere in this module.
        slot = self.letters.index(self.letter) if self.letter in self.letters else None
        return self.parent, slot

    def generate_child(self, child=0, last=False):
        """Create, store and return a new child in slot ``child``.

        Raises:
            AssertionError: If the slot is already expanded.
        """
        assert not self.children_stat[child], "Already tree generated child at here"

        new_node = Node(
            letter=self.letters[child],
            parent=self,
            last=last,
            depth=self.depth + 1,
            states=self.states,
        )
        self.children[child] = new_node
        self.children_stat[child] = True

        return new_node

    def backpropagation(self, score=0):
        """Propagate ``score`` from this node up to the root.

        Every node on the path (root included) gets one more visit; every
        non-root node also adds ``score`` to its exploitation score.

        Returns:
            The root's ``exploitation_score`` (the root never accumulates
            score, so this is whatever it was initialised to).
        """
        node = self
        while True:
            node.visits += 1
            if node.root:  # The root only counts the visit and ends the walk
                return node.exploitation_score
            node.exploitation_score += score
            node = node.parent

    def UCT(self):
        """Return this node's UCT value.

        Mean score plus the exploration bonus
        ``sqrt(ln(parent.visits) / (2 * visits))``. Needs a parent, so it must
        not be called on the root.
        """
        mean_score = self.exploitation_score / self.visits
        exploration = np.sqrt(np.log(self.parent.visits) / (2 * self.visits))
        return mean_score + exploration
    

#MCTS
class MCTS:
    """Monte Carlo tree search that grows a sequence two characters at a time.

    A sequence is stored as a string of two-character letters. ``"X_"`` means
    "put X at the front" and ``"_X"`` means "put X at the back" (see
    ``reconstruct``). A finished sequence has ``depth`` letters, i.e.
    ``depth * 2`` characters.

    Args:
        target_encoded: Encoded target tensor handed to the classifier.
        depth: Number of letters in a finished sequence.
        iteration: Number of ``select`` steps per search round.
        states: Number of possible letters (children) per node.
        target_protein: Target protein's amino acid sequence (stored only).
        device: Device the classifier and its inputs are moved to when a
            rollout is scored.
        esm_alphabet: Object exposing ``padding_idx`` for the target encoding.
    """

    def __init__(self, target_encoded, depth=20, iteration=1000, states=8, target_protein="", device='cpu', esm_alphabet=None):
        self.states = states  # How many states
        self.root = Node(letter="", parent=None, root=True, last=False, states=self.states)  # Root node
        self.depth = depth  # Maximum depth
        self.iteration = iteration  # Select/expand steps per round
        self.target_protein = target_protein  # Target protein's amino acid sequence
        self.device = device
        self.encoded_targetprotein = target_encoded
        self.base = ""  # Letters fixed so far
        self.candidate = ""  # Final letter string once the search finishes
        self.letters = ["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
        self.esm_alphabet = esm_alphabet
        self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)

    def make_candidate(self, classifier):
        """Run search rounds until ``depth`` letters are fixed.

        Each round performs ``iteration`` select/expand steps, commits the
        best path found to ``self.base`` and restarts from a fresh root.

        Returns:
            The finished letter string (also stored in ``self.candidate``).

        Raises:
            RuntimeError: If a round adds no letter (e.g. ``iteration`` is 0),
                which would otherwise loop forever.
        """
        now = self.root
        n = 0  # Round counter

        while len(self.base) < self.depth * 2:
            n += 1
            print(n, "round start!!!")
            for _ in range(self.iteration):
                now = self.select(classifier, now=now)  # Select & expand

            base = self.find_best_subsequence()
            if len(base) <= len(self.base):
                raise RuntimeError("MCTS round made no progress; is `iteration` at least 1?")
            self.base = base

            # Integer division keeps the node depth an int (each letter is 2 chars).
            self.root = Node(letter="", parent=None, root=True, last=False, states=self.states, depth=len(self.base) // 2)
            now = self.root

        self.candidate = self.base

        return self.candidate

    def _best_child(self, node):
        """Return the slot of ``node``'s expanded child with the highest UCT, or None."""
        best_slot = None
        best_score = float('-inf')  # Lets non-positive UCT values win too
        for slot in range(self.states):
            if node.children_stat[slot]:
                uct = node.children[slot].UCT()
                if best_score < uct:
                    best_slot, best_score = slot, uct
        return best_slot

    def select(self, classifier, now=None):
        """Take one search step from ``now`` and return the node to continue from.

        * At maximum depth: return to the root.
        * All children expanded: descend into the child with the highest UCT.
        * Otherwise pick a random slot; expand and score it if it is new (then
          return to the root), or descend into it if it already exists.
        """
        if now.depth == self.depth:  # Maximum depth reached
            return self.root

        if np.sum(now.children_stat) == self.states:
            best = self._best_child(now)
            # Fall back to slot 0 only if no UCT value was comparable (e.g. NaN).
            return now.next_node(child=0 if best is None else best)

        pick = np.random.randint(0, self.states)
        if not now.children_stat[pick]:
            self.expand(classifier, child=pick, now=now)
            return self.root  # Restart the descent from the root

        return now.next_node(child=pick)

    def expand(self, classifier, child=None, now=None):
        """Create child ``child`` of ``now``, score it by rollout and back-propagate.

        Returns:
            The slot index ``child`` that was expanded.
        """
        # A child of a node at depth - 1 sits at maximum depth.
        last = now.depth == (self.depth - 1)

        expanded_node = now.generate_child(child=child, last=last)

        score = self.simulate(classifier, target=expanded_node)
        expanded_node.backpropagation(score=score)

        return child

    def _rollout_sequence(self, target):
        """Return base + path to ``target`` + random letters, ``depth`` letters long."""
        path = []
        node = target
        while not node.root:  # Collect letters from target up to the root
            path.append(node.letter)
            node = node.parent
        prefix = self.base + "".join(reversed(path))

        # Every letter is two characters, so count missing *letters*; counting
        # missing characters would make the rollout roughly twice too long.
        missing = self.depth - len(prefix) // 2
        tail = [self.letters[np.random.randint(0, self.states)] for _ in range(missing)]
        return prefix + "".join(tail)

    def simulate(self, classifier, target=None):
        """Score ``target`` with one random rollout.

        The path to ``target`` is completed with random letters, turned into a
        plain sequence, tokenized and passed to ``classifier`` together with
        the encoded target. Model and tensors are moved to ``self.device``.

        Returns:
            The first element of the classifier's 4-tuple output.
        """
        sim_seq = self.reconstruct(self._rollout_sequence(target))

        classifier.eval().to(self.device)
        with torch.no_grad():
            sim_seq = np.array([sim_seq])

            apta_toks = self.nt_tokenizer.batch_encode_plus(sim_seq, return_tensors='pt', padding='max_length', max_length=275)['input_ids']
            apta_attention_mask = apta_toks != self.nt_tokenizer.pad_token_id
            prot_attention_mask = self.encoded_targetprotein != self.esm_alphabet.padding_idx
            score, _, _, _ = classifier(
                apta_toks.to(self.device),
                self.encoded_targetprotein.to(self.device),
                apta_attention_mask.to(self.device),
                prot_attention_mask.to(self.device),
            )

        return score

    def get_candidate(self):
        """Return the final candidate as a plain (reconstructed) sequence."""
        return self.reconstruct(self.candidate)

    def find_best_subsequence(self):
        """Follow the highest-UCT children from the root and return the extended base.

        The walk stops when ``depth`` letters are reached or the current node
        has no expanded child.
        """
        now = self.root
        base = self.base

        for _ in range(self.depth - len(base) // 2):
            slot = self._best_child(now)
            if slot is None:  # Nothing expanded below this node
                break
            now = now.next_node(child=slot)
            base += now.letter

        return base

    def reconstruct(self, seq=""):
        """Turn a letter string into the plain sequence it encodes.

        Letters are read left to right: ``"X_"`` prepends X to the result and
        ``"_X"`` appends X.
        """
        front = []  # Prepended characters, in the order they were read
        back = []  # Appended characters
        for i in range(0, len(seq), 2):
            if seq[i] == '_':
                back.append(seq[i + 1])
            else:
                front.append(seq[i])
        # The most recently prepended character ends up first.
        return "".join(reversed(front)) + "".join(back)

    def reset(self):
        """Forget all search progress and start from an empty tree."""
        self.base = ""
        self.candidate = ""
        self.root = Node(letter="", parent=None, root=True, last=False, states=self.states)