Atom Bioworks commited on
Commit
e4cb743
1 Parent(s): 2616ade

Create mcts.py

Browse files
Files changed (1) hide show
  1. mcts.py +208 -0
mcts.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import timeit
3
+ import torch
4
+ from utils import rna2vec
5
+ from transformers import AutoTokenizer
6
+
7
+ #Node
8
+ class Node:
9
+ #init
10
+ def __init__(self, letter="", parent=None, root=False, last=False, depth=0, states=8):
11
+ self.exploitation_score = 0 # Exploitaion score
12
+ self.visits = 1 #How many visits
13
+ self.letter = letter #This node's letter
14
+ self.parent = parent #This node's parent node
15
+ self.states = states #How many states in node
16
+ self.children = np.array([None for _ in range(self.states)]) #This node's children
17
+ self.children_stat = np.zeros(self.states, dtype=bool) #Which stat are expanded
18
+ self.root = root # Is root? boolean
19
+ self.last = last # Is last node?
20
+ self.depth = depth # My depth
21
+ self.letters =["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
22
+
23
+
24
+ #next_node
25
+ def next_node(self, child=0): #Return next node
26
+ assert self.children_stat[child] == True, "No child in here."
27
+
28
+ return self.children[child]
29
+
30
+ #back_parent
31
+ def back_parent(self): #Go back to parent
32
+ return self.parent, letters_map[self.letter]
33
+
34
+ #generate_child
35
+ def generate_child(self, child=0, last=False): #Generate child
36
+ assert self.children_stat[child] == False, "Already tree generated child at here"
37
+
38
+ self.children[child] = Node(letter=self.letters[child], parent=self, last=last, depth=self.depth+1, states=self.states) #New node
39
+ self.children_stat[child] = True #Stat = True
40
+
41
+ return self.children[child]
42
+
43
+ #backpropagation
44
+ def backpropagation(self, score=0):
45
+ self.visits += 1 # +1 to visit
46
+ if self.root == True: # if root, then stop
47
+ return self.exploitation_score
48
+
49
+ else:
50
+ self.exploitation_score += score #Add score to exploitation score
51
+ return self.parent.backpropagation(score=score) #Backpropagation to parent node
52
+
53
+ #UCT
54
+ def UCT(self):
55
+ return (self.exploitation_score / self.visits) + np.sqrt(np.log(self.parent.visits) / (2 * self.visits)) #UCT score
56
+
57
+
58
+ #MCTS
59
+ class MCTS:
60
+ def __init__(self, target_encoded, depth=20, iteration=1000, states=8, target_protein="", device='cpu', esm_alphabet=None):
61
+ self.states = states #How many states
62
+ self.root = Node(letter="", parent=None, root=True, last=False, states=self.states) #root node
63
+ self.depth = depth #Maximum depth
64
+ self.iteration = iteration #iteration for expand
65
+ self.target_protein = target_protein #target protein's amino acid sequence
66
+ self.device = device
67
+ self.encoded_targetprotein = target_encoded
68
+ self.base = ""
69
+ self.candidate = ""
70
+ self.letters =["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
71
+ self.esm_alphabet = esm_alphabet
72
+ self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)
73
+
74
+
75
+ def make_candidate(self, classifier):
76
+ now = self.root
77
+ n = 0 # rounds
78
+ start_time = timeit.default_timer() #timer start
79
+
80
+ while len(self.base) < self.depth * 2: #If now is last node, then stop
81
+ n += 1
82
+ print(n, "round start!!!")
83
+ for _ in range(self.iteration):
84
+ now = self.select(classifier, now=now) #Select & Expand
85
+
86
+ terminate_time = timeit.default_timer()
87
+ time = terminate_time-start_time
88
+
89
+ base = self.find_best_subsequence() #Find best subsequence
90
+ self.base = base
91
+
92
+ # print("best subsequence:", base)
93
+ # print("Depth:", int(len(base)/2))
94
+ # print("%02d:%02d:%2f" % ((time//3600), (time//60)%60, time%60))
95
+ # print("=" * 80)
96
+
97
+ self.root = Node(letter="", parent=None, root=True, last=False, states=self.states, depth=len(self.base)/2)
98
+ now = self.root
99
+
100
+ self.candidate = self.base
101
+
102
+ return self.candidate
103
+
104
+ #selection
105
+ def select(self, classifier, now=None):
106
+ if now.depth == self.depth: #If last node, then stop
107
+ return self.root
108
+
109
+ next_node = 0
110
+ if np.sum(now.children_stat) == self.states: #If every child is expanded, then go to best child
111
+ best = 0
112
+ for i in range(self.states):
113
+ if best < now.children[i].UCT():
114
+ next_node = i
115
+ best = now.children[i].UCT()
116
+
117
+ else: #If not, then random
118
+ next_node = np.random.randint(0, self.states)
119
+ if now.children_stat[next_node] == False: #If selected child is not expanded, then expand and simulate
120
+ next_node = self.expand(classifier, child=next_node, now=now)
121
+
122
+ return self.root #start iteration at this node
123
+
124
+ return now.next_node(child=next_node)
125
+
126
+ #expand
127
+ def expand(self, classifier, child=None, now=None):
128
+ last = False
129
+ if now.depth == (self.depth-1): #If depth of this node is maximum depth -1, then next node is last
130
+ last = True
131
+
132
+ expanded_node = now.generate_child(child=child, last=last) #Expand
133
+
134
+ score = self.simulate(classifier, target=expanded_node) #Simulate
135
+ expanded_node.backpropagation(score=score) #Backporpagation
136
+
137
+ return child
138
+
139
+ #simulate
140
+ def simulate(self, classifier, target=None):
141
+ now = target #Target node
142
+ sim_seq = ""
143
+
144
+ while now.root != True: #Parent's letters
145
+ sim_seq = now.letter + sim_seq
146
+ now = now.parent
147
+
148
+ sim_seq = self.base + sim_seq
149
+
150
+ for i in range((self.depth * 2) - len(sim_seq)): #Random child letters
151
+ r = np.random.randint(0,self.states)
152
+ sim_seq += self.letters[r]
153
+
154
+ sim_seq = self.reconstruct(sim_seq)
155
+ scores = []
156
+
157
+ classifier.eval().to('cuda')
158
+ with torch.no_grad():
159
+ sim_seq = np.array([sim_seq])
160
+
161
+ apta_toks = self.nt_tokenizer.batch_encode_plus(sim_seq, return_tensors='pt', padding='max_length', max_length=275)['input_ids']
162
+ apta_attention_mask = apta_toks != self.nt_tokenizer.pad_token_id
163
+ prot_attention_mask = self.encoded_targetprotein != self.esm_alphabet.padding_idx
164
+ score, _, _, _ = classifier(apta_toks.to('cuda'), self.encoded_targetprotein.to('cuda'), apta_attention_mask.to('cuda'), prot_attention_mask.to('cuda'))
165
+
166
+ return score
167
+
168
+ #recommend
169
+ def get_candidate(self):
170
+ return self.reconstruct(self.candidate)
171
+
172
+ def find_best_subsequence(self):
173
+ now = self.root
174
+ stop = False
175
+ base = self.base
176
+
177
+ for _ in range((self.depth*2) - len(base)):
178
+ best = 0
179
+ next_node = 0
180
+ for j in range(self.states):
181
+ if now.children_stat[j] == True:
182
+ if best < now.children[j].UCT():
183
+ next_node = j
184
+ best = now.children[j].UCT()
185
+
186
+ now = now.next_node(child=next_node)
187
+ base += now.letter
188
+
189
+ # if current node has no expanded children, stop reconstructing.
190
+ if np.sum(now.children_stat) == 0:
191
+ break
192
+
193
+ return base
194
+
195
+ #reconstruct
196
+ def reconstruct(self, seq=""):
197
+ r_seq = ""
198
+ for i in range(0, len(seq), 2):
199
+ if seq[i] == '_':
200
+ r_seq = r_seq + seq[i+1]
201
+ else:
202
+ r_seq = seq[i] + r_seq
203
+ return r_seq
204
+
205
+ def reset(self):
206
+ self.base = ""
207
+ self.candidate = ""
208
+ self.root = Node(letter="", parent=None, root=True, last=False, states=self.states)