Update replay_main.py

replay_main.py (changed, +22 -22)
@@ -9,7 +9,7 @@ from src.trainer.reinforce import REINFORCETrainer, REINFORCEConfig
 from copy import deepcopy
 from tqdm import tqdm
 from src.utils import print_text
-from src.utils import Config
+from src.utils import Config, count_repeated_ngrams
 from matplotlib import pyplot as plt
 import time

@@ -112,13 +112,19 @@ def main(args):
     while not reach_terminal:
         # try:
         print("searching for next step")
-
+        if current_step >= 30:
+            break
         # response_tensors, skip, current_step_response_tensors = respond_to_batch(old_model, step_tag_id_policy, tokenizer.eos_token_id, current_sol_prefix_tensor.repeat(args.k, 1), txt_len=1100, top_p=1.0, sampling_steps=args.num_truncated_rollouts, current_step=current_step, tokenizer=tokenizer)

         with torch.no_grad():
-            response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=
+            response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=1200-len(current_sol_prefix_tensor[0]), batch_size=6, do_sample=True, temperature=temp)

         proposed_solutions = [tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True) for response_tensor in response_tensors]
+        for id, solution in enumerate(proposed_solutions):
+            cnt = count_repeated_ngrams(solution, tokenizer, 3)
+            if cnt >= 10:
+                proposed_solutions
+
         all_proposed_solutions.extend(proposed_solutions)

         current_step_indices = []
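Note: the repetition filter above calls count_repeated_ngrams from src.utils, which is not part of this diff, and the branch under "if cnt >= 10:" is cut off in the diff view. A minimal sketch of what such a counter could look like, assuming a tokenizer with an .encode() method and counting occurrences beyond the first of each token n-gram (the names and the exact counting rule here are illustrative assumptions, not the repo's implementation):

    from collections import Counter

    def count_repeated_ngrams_sketch(text, tokenizer, n):
        # Count n-gram occurrences beyond the first over the token ids of `text`.
        token_ids = tokenizer.encode(text)  # assumes a HF-style .encode()
        ngrams = [tuple(token_ids[i:i + n]) for i in range(len(token_ids) - n + 1)]
        return sum(c - 1 for c in Counter(ngrams).values() if c > 1)

    class _WhitespaceTok:
        # toy tokenizer so the sketch runs standalone
        def encode(self, text):
            return text.split()

    print(count_repeated_ngrams_sketch("a b c a b c a b c", _WhitespaceTok(), 3))  # 4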
@@ -143,21 +149,9 @@ def main(args):
         all_proposed_rewards.extend(rewards)

         # choosing the base next step based cumulative lookahead rewards
-        step_rewards = [reward[current_step:].mean() for reward in rewards]
+        step_rewards = [(reward[current_step:].mean() + reward[current_step]) if (len(reward) > current_step) else reward.mean() for reward in rewards]
         most_probable_base_next_step_tensor = current_step_response_tensors[step_rewards.index(max(step_rewards))]

-        # # force resampling if 1/2 the sampled lookaheads are suck
-        # is_rw_nan = 0
-        # for i, reward in enumerate(step_rewards):
-        #     next_tokens = tokenizer.decode(current_step_response_tensors[i])
-        #     if (next_tokens != tokenizer.eos_token) and :
-        #         is_rw_nan += 1
-
-        # if is_rw_nan >= 1:
-        #     print("We are resampling")
-        #     temp += 0.3
-        #     rp = 1.0
-        #     continue
         print("Most probable base step: ")
         print(tokenizer.decode(most_probable_base_next_step_tensor))
         print(step_rewards)
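The new step_rewards expression scores each rollout by the mean of its lookahead rewards plus the immediate reward at current_step, and falls back to the plain mean when the reward vector is too short to index at current_step. A toy illustration of that selection rule (the reward values below are made up):

    import torch

    def score_rollout(reward, current_step):
        # Mirrors the updated list comprehension: lookahead mean plus the
        # immediate step reward, with a fallback for short reward vectors.
        if len(reward) > current_step:
            return reward[current_step:].mean() + reward[current_step]
        return reward.mean()

    rewards = [torch.tensor([0.2, 0.9, 0.7]), torch.tensor([0.4])]  # made-up per-step rewards
    step_rewards = [score_rollout(r, current_step=1) for r in rewards]
    best = step_rewards.index(max(step_rewards))  # 0: the first rollout wins (0.8 + 0.9 vs 0.4)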
@@ -169,12 +163,17 @@ def main(args):
             step_rewards[id] = step_rewards[id_max]

         # pg update
-
-
-
+        try:
+            if not any([torch.isnan(reward).item() for reward in step_rewards]):
+                print([torch.isnan(reward) for reward in step_rewards])
+                print(any([torch.isnan(reward).item() for reward in step_rewards]))
+                stats = reinforce_trainer.step([current_sol_prefix_tensor[0] for i in range(args.k)], [response_tensors[i] for i in range(args.k)], rewards, logged=False)
+                print(stats)
+        except:
+            continue
         # pg evaluation
         with torch.no_grad():
-            search_response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=
+            search_response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=1200-len(current_sol_prefix_tensor[0]), batch_size=6, do_sample=True, temperature=temp)
         search_proposed_solutions = [tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True) for response_tensor in search_response_tensors]
         all_proposed_solutions.extend(search_proposed_solutions)

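The pg update is now wrapped so that reinforce_trainer.step(...) only runs when none of the step rewards contain NaN, and any exception makes the loop continue instead of crashing the run. The guard itself reduces to a check like this (the reward values are made up):

    import torch

    step_rewards = [torch.tensor(0.8), torch.tensor(float("nan"))]  # made-up values
    if not any(torch.isnan(r).item() for r in step_rewards):
        pass  # safe to run the policy-gradient update here
    else:
        print("NaN reward detected, skipping the update")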
@@ -200,8 +199,9 @@ def main(args):
         all_proposed_rewards.extend(search_rewards)

         # choosing the pg search next step based cumulative lookahead rewards
-        search_step_rewards = [reward[current_step:].mean() for reward in search_rewards]
+        search_step_rewards = [(reward[current_step:].mean() + reward[current_step]) if (len(reward) > current_step) else reward.mean() for reward in search_rewards]
         most_probable_search_next_step_tensor = current_search_step_response_tensors[search_step_rewards.index(max(search_step_rewards))]
+
         print("Most probable search step: ")
         print(tokenizer.decode(most_probable_search_next_step_tensor))
         print(search_step_rewards)
@@ -242,7 +242,7 @@ def main(args):
         solution = all_proposed_solutions[all_proposed_rewards.index(max(all_proposed_rewards))]

         _, prediction = math_evaluation(solution, example["answer"])
-        is_passed, main_prediction = math_evaluation(
+        is_passed, main_prediction = math_evaluation(solution, example["answer"])
         accuracy += is_passed

         print("Final solution: ", solution)