huypn16 committed
Commit 2d24af3
1 Parent(s): 2c87c16

Update replay_main.py

Files changed (1):
  1. replay_main.py +22 -22
replay_main.py CHANGED
@@ -9,7 +9,7 @@ from src.trainer.reinforce import REINFORCETrainer, REINFORCEConfig
 from copy import deepcopy
 from tqdm import tqdm
 from src.utils import print_text
-from src.utils import Config
+from src.utils import Config, count_repeated_ngrams
 from matplotlib import pyplot as plt
 import time
 
@@ -112,13 +112,19 @@ def main(args):
     while not reach_terminal:
         # try:
         print("searching for next step")
-
+        if current_step >= 30:
+            break
         # response_tensors, skip, current_step_response_tensors = respond_to_batch(old_model, step_tag_id_policy, tokenizer.eos_token_id, current_sol_prefix_tensor.repeat(args.k, 1), txt_len=1100, top_p=1.0, sampling_steps=args.num_truncated_rollouts, current_step=current_step, tokenizer=tokenizer)
 
         with torch.no_grad():
-            response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=1100, batch_size=6, do_sample=True, temperature=temp)
+            response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=1200-len(current_sol_prefix_tensor[0]), batch_size=6, do_sample=True, temperature=temp)
 
         proposed_solutions = [tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True) for response_tensor in response_tensors]
+        for id, solution in enumerate(proposed_solutions):
+            cnt = count_repeated_ngrams(solution, tokenizer, 3)
+            if cnt >= 10:
+                proposed_solutions
+
         all_proposed_solutions.extend(proposed_solutions)
 
         current_step_indices = []
@@ -143,21 +149,9 @@ def main(args):
         all_proposed_rewards.extend(rewards)
 
         # choosing the base next step based cumulative lookahead rewards
-        step_rewards = [reward[current_step:].mean() for reward in rewards]
+        step_rewards = [(reward[current_step:].mean() + reward[current_step]) if (len(reward) > current_step) else reward.mean() for reward in rewards]
         most_probable_base_next_step_tensor = current_step_response_tensors[step_rewards.index(max(step_rewards))]
 
-        # # force resampling if 1/2 the sampled lookaheads are suck
-        # is_rw_nan = 0
-        # for i, reward in enumerate(step_rewards):
-        #     next_tokens = tokenizer.decode(current_step_response_tensors[i])
-        #     if (next_tokens != tokenizer.eos_token) and :
-        #         is_rw_nan += 1
-
-        # if is_rw_nan >= 1:
-        #     print("We are resampling")
-        #     temp += 0.3
-        #     rp = 1.0
-        #     continue
         print("Most probable base step: ")
         print(tokenizer.decode(most_probable_base_next_step_tensor))
         print(step_rewards)
@@ -169,12 +163,17 @@ def main(args):
         step_rewards[id] = step_rewards[id_max]
 
         # pg update
-        stats = reinforce_trainer.step([current_sol_prefix_tensor[0] for i in range(args.k)], [response_tensors[i] for i in range(args.k)], rewards, logged=False)
-        print(stats)
-
+        try:
+            if not any([torch.isnan(reward).item() for reward in step_rewards]):
+                print([torch.isnan(reward) for reward in step_rewards])
+                print(any([torch.isnan(reward).item() for reward in step_rewards]))
+                stats = reinforce_trainer.step([current_sol_prefix_tensor[0] for i in range(args.k)], [response_tensors[i] for i in range(args.k)], rewards, logged=False)
+                print(stats)
+        except:
+            continue
         # pg evaluation
         with torch.no_grad():
-            search_response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=1100, batch_size=6, do_sample=True, temperature=0.9)
+            search_response_tensors = reinforce_trainer.generate([current_sol_prefix_tensor[0] for i in range(args.k)], return_prompt=False, max_new_tokens=1200-len(current_sol_prefix_tensor[0]), batch_size=6, do_sample=True, temperature=temp)
         search_proposed_solutions = [tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True) for response_tensor in search_response_tensors]
         all_proposed_solutions.extend(search_proposed_solutions)
 
@@ -200,8 +199,9 @@ def main(args):
         all_proposed_rewards.extend(search_rewards)
 
         # choosing the pg search next step based cumulative lookahead rewards
-        search_step_rewards = [reward[current_step:].mean() for reward in search_rewards]
+        search_step_rewards = [(reward[current_step:].mean() + reward[current_step]) if (len(reward) > current_step) else reward.mean() for reward in search_rewards]
         most_probable_search_next_step_tensor = current_search_step_response_tensors[search_step_rewards.index(max(search_step_rewards))]
+
         print("Most probable search step: ")
         print(tokenizer.decode(most_probable_search_next_step_tensor))
         print(search_step_rewards)
@@ -242,7 +242,7 @@ def main(args):
    solution = all_proposed_solutions[all_proposed_rewards.index(max(all_proposed_rewards))]

    _, prediction = math_evaluation(solution, example["answer"])
-   is_passed, main_prediction = math_evaluation(main_solution, example["answer"])
+   is_passed, main_prediction = math_evaluation(solution, example["answer"])
    accuracy += is_passed

    print("Final solution: ", solution)
 