huypn16 committed on
Commit
a71c743
1 Parent(s): 88a70a3

replay_main.py

Files changed (1)
  1. replay_main.py +188 -0
replay_main.py ADDED
@@ -0,0 +1,188 @@
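
# Step-wise search evaluation script: each PRM800K test problem is solved one reasoning
# step at a time. At every step the script samples --k rollouts from the base policy and,
# after a REINFORCE update on the LoRA adapter, --h rollouts from the adapted ("PGS")
# policy; a process reward model (PRM) scores both sets and the higher-scoring candidate
# next step is committed. --num-truncated-rollouts is passed to respond_to_batch as
# sampling_steps, i.e. how far each truncated rollout is unrolled.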
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PromptTuningInit, PromptTuningConfig, TaskType, LoraConfig, AutoPeftModelForCausalLM, get_peft_model
from argparse import ArgumentParser
from src.core import respond_to_batch, get_rewards, math_evaluation
from src.trainer.reinforce import REINFORCETrainer, REINFORCEConfig
from copy import deepcopy
from tqdm import tqdm
from src.utils import print_text
from src.utils import Config
from matplotlib import pyplot as plt
import time

def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model_sft_name", type=str)
    parser.add_argument("--base_name", type=str, default="peiyi9979/mistral-7b-sft")
    parser.add_argument("--reward_model_name", type=str, default="peiyi9979/math-shepherd-mistral-7b-prm")
    parser.add_argument("--lora_rank", type=int, default=32)
    parser.add_argument("--max_search_retries", type=int, default=6)
    parser.add_argument("--k", type=int, default=8)
    parser.add_argument("--h", type=int, default=4)
    parser.add_argument("--numeval", type=int, default=55)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--threshold", type=float, default=0.90)
    parser.add_argument("--num-truncated-rollouts", type=int, default=6)
    return parser.parse_args()

def main(args):
    if args.verbose:
        Config.VERBOSE = True

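    # Evaluation set: PRM800K stepwise-critic test split, keeping only examples that are
    # marked as complete solutions and truncating to the first --numeval problems.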
    prm_dataset = load_dataset('Birchlabs/openai-prm800k-stepwise-critic')
    test = prm_dataset['test']
    testset_with_labels = []
    for i in range(len(test)):
        if test[i]["is_solution"]:
            testset_with_labels.append(test[i])
    testset_with_labels = testset_with_labels[:args.numeval]

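    # Policy setup: load the base model, attach a fresh LoRA adapter, and leave only the
    # LoRA weights and the lm_head trainable for the per-step REINFORCE updates.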
    foundational_model = AutoModelForCausalLM.from_pretrained(args.base_name, torch_dtype=torch.bfloat16, use_cache=True).cpu()
    # foundational_model.load_adapter(args.model_sft_name)
    tokenizer = AutoTokenizer.from_pretrained(args.base_name)
    tokenizer.pad_token = tokenizer.eos_token
    foundational_model.enable_input_require_grads()

    lora_config = LoraConfig(
        peft_type=TaskType.CAUSAL_LM,
        r=args.lora_rank,
    )

    foundational_model = get_peft_model(foundational_model, lora_config)

    for name, param in foundational_model.named_parameters():
        if "lm_head" in name:
            param.requires_grad = True
        elif "lora" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # keep a copy of the freshly initialized adapter so the policy can be reset between steps
    old_model = deepcopy(foundational_model)

    foundational_model = torch.nn.DataParallel(foundational_model)
    foundational_model.to('cuda:0')

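    # math-shepherd PRM conventions: each reasoning step in a solution ends with the 'ки'
    # step tag, and get_rewards scores every tagged step by the PRM's preference for the
    # '+' (good step) token over the '-' (bad step) token at that position.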
    good_token = '+'
    bad_token = '-'
    step_tag = 'ки'

    rm_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name)
    candidate_tokens = rm_tokenizer.encode(f"{good_token} {bad_token}")[1:]  # [648, 387]
    step_tag_id = rm_tokenizer.encode(f"{step_tag}")[-1]
    rm_model = AutoModelForCausalLM.from_pretrained(args.reward_model_name).eval()
    rm_model.to('cuda:1')

    # step-tag token id under the policy model's own tokenizer (it may differ from the PRM's)
    step_tag_id_policy = tokenizer.encode(f"{step_tag}")[-1]

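    # REINFORCE trainer: performs the per-step policy-gradient updates on the LoRA adapter,
    # using the k sampled rollouts and their PRM scores as the training batch.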
    reinforce_config = REINFORCEConfig(batch_size=args.k, mini_batch_size=4, step_tag_id=step_tag_id_policy, reinforce_epochs=6)
    reinforce_trainer = REINFORCETrainer(reinforce_config, foundational_model, tokenizer=tokenizer)
    accuracy = 0.0
    cnt = 0

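    # Main loop: for each problem, repeatedly (1) sample k truncated rollouts from the
    # current policy, (2) score them with the PRM, (3) take a REINFORCE step on the LoRA
    # adapter, (4) sample h rollouts from the adapted policy, and (5) commit whichever
    # candidate next step has the higher mean PRM score. The adapter is reset to its
    # initial state after every committed step.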
    for example in tqdm(testset_with_labels):
        cnt += 1
        if cnt < 12:
            continue
        problem_statement = example["instruction"]
        print("--------------------------We are solving this problem: ------------------")
        print(problem_statement)
        reach_terminal = False
        query_tensors = tokenizer.encode(problem_statement, return_tensors="pt").to('cuda:0')
        current_sol_prefix_tensor = query_tensors
        current_step = 0
        while not reach_terminal:
            # try:
            print("searching for next step")
            response_tensors, skip, current_step_response_tensors = respond_to_batch(foundational_model, step_tag_id_policy, tokenizer.eos_token_id, current_sol_prefix_tensor.repeat(args.k, 1), txt_len=1100, top_p=1.0, sampling_steps=args.num_truncated_rollouts, current_step=current_step, tokenizer=tokenizer)
            proposed_solutions = [tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True) for response_tensor in response_tensors]

            rm_model = rm_model.to('cuda:1')
            rewards = get_rewards(rm_model, rm_tokenizer, problem_statement, proposed_solutions, candidate_tokens, step_tag_id)
            rm_model = rm_model.to('cpu')
            # list of k reward tensors, each of shape (truncated_number_of_actions,)
            rewards = [reward.to('cuda:1') for reward in rewards]

            step_rewards = [reward[current_step:].mean() for reward in rewards]
            most_probable_base_next_step_tensor = current_step_response_tensors[step_rewards.index(max(step_rewards))]
            print(tokenizer.decode(most_probable_base_next_step_tensor))
            print(step_rewards)

            # pg rollout: one REINFORCE update on the k rollouts scored above
            stats = reinforce_trainer.step([current_sol_prefix_tensor[0] for i in range(args.k)], [response_tensors[i] for i in range(args.k)], rewards, logged=False)
            print(stats)

            # pg evaluation: re-sample h rollouts from the adapted policy and score them
            search_response_tensors, skip, current_search_step_response_tensors = respond_to_batch(foundational_model, step_tag_id_policy, tokenizer.eos_token_id, current_sol_prefix_tensor.repeat(args.h, 1), txt_len=1100, top_p=1.0, sampling_steps=args.num_truncated_rollouts, current_step=current_step, tokenizer=tokenizer)
            search_proposed_solutions = [tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True) for response_tensor in search_response_tensors]

            rm_model = rm_model.to('cuda:1')
            search_rewards = get_rewards(rm_model, rm_tokenizer, problem_statement, search_proposed_solutions, candidate_tokens, step_tag_id)
            rm_model = rm_model.to("cpu")
            search_rewards = [reward.to('cuda:1') for reward in search_rewards]
            search_step_rewards = [reward[current_step:].mean() for reward in search_rewards]
            most_probable_search_next_step_tensor = current_search_step_response_tensors[search_step_rewards.index(max(search_step_rewards))]
            print(tokenizer.decode(most_probable_search_next_step_tensor))
            print(search_step_rewards)

            # choose the next step from the base policy or the PGS adapted policy,
            # whichever candidate scores higher under the PRM
            most_probable_next_step_tensor = most_probable_base_next_step_tensor if max(step_rewards) > max(search_step_rewards) else most_probable_search_next_step_tensor
            print("We chose the PGS adapted policy" if max(step_rewards) <= max(search_step_rewards) else "We chose the base policy")

            current_sol_prefix_tensor = torch.cat([current_sol_prefix_tensor, most_probable_next_step_tensor.unsqueeze(0)], dim=-1)

            if tokenizer.eos_token_id in current_sol_prefix_tensor:
                reach_terminal = True

            # reset the policy to the original adapter after each committed step
            foundational_model = deepcopy(old_model)
            foundational_model = torch.nn.DataParallel(foundational_model)
            foundational_model.to('cuda:0')

            current_step += 1

            del response_tensors, proposed_solutions, rewards

        # run evaluation: decode the finished solution and grade the final answer
        solution = tokenizer.decode(current_sol_prefix_tensor[0], skip_special_tokens=True)

        is_passed, prediction = math_evaluation(solution, example["answer"])
        accuracy += is_passed

        print("Final solution: ", solution)
        print("Final answer: ", prediction)
        print("Correct answer:", example["answer"])
        print("Problem done with correct solution: ", is_passed)

        # reset the model before moving on to the next problem
        foundational_model = deepcopy(old_model)
        foundational_model = torch.nn.DataParallel(foundational_model)
        foundational_model.to('cuda:0')

        print(f"Accuracy: {accuracy/(cnt-5)}")

    print(f"Accuracy: {accuracy/(len(testset_with_labels) - 5)}")


if __name__ == '__main__':
    args = parse_args()
    main(args)
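

# Example invocation (values shown are the defaults defined in parse_args; adjust as needed):
#   python replay_main.py --base_name peiyi9979/mistral-7b-sft \
#       --reward_model_name peiyi9979/math-shepherd-mistral-7b-prm \
#       --k 8 --h 4 --numeval 55 --num-truncated-rollouts 6 --verbose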