from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PromptTuningInit, PromptTuningConfig, TaskType, LoraConfig, AutoPeftModelForCausalLM, get_peft_model
from argparse import ArgumentParser
from src.core import respond_to_batch, get_rewards, math_evaluation
from src.trainer.reinforce import REINFORCETrainer, REINFORCEConfig
from copy import deepcopy
from tqdm import tqdm
from src.utils import print_text
from src.utils import Config, count_repeated_ngrams
from matplotlib import pyplot as plt
import time


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model_sft_name", type=str)
    parser.add_argument("--base_name", type=str, default="peiyi9979/mistral-7b-sft")
    parser.add_argument("--reward_model_name", type=str, default="peiyi9979/math-shepherd-mistral-7b-prm")
    parser.add_argument("--lora_rank", type=int, default=32)
    parser.add_argument("--k", type=int, default=6)
    parser.add_argument("--numeval", type=int, default=55)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--num-truncated-rollouts", type=int, default=4)
    parser.add_argument("--epsilon", type=float, default=0)
    parser.add_argument("--predictive_horizon", type=int, default=2)
    return parser.parse_args()
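
# A minimal sketch of the lookahead scoring used inside the search loop below, written out
# only for illustration (it is not called by this script). It assumes `rewards` is a 1-D
# tensor of per-step PRM scores for a single rollout, matching the list comprehensions that
# compute `step_rewards` and `search_step_rewards` in main().
def lookahead_step_score(rewards: torch.Tensor, current_step: int) -> torch.Tensor:
    """Score a rollout as the reward at the current step plus the mean reward from that step onward."""
    if len(rewards) > current_step:
        return rewards[current_step:].mean() + rewards[current_step]
    return rewards.mean()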

def main(args):
    if args.verbose:
        Config.VERBOSE = True

    # Keep only test problems that come with a reference solution, capped at --numeval.
    prm_dataset = load_dataset('Birchlabs/openai-prm800k-stepwise-critic')
    test = prm_dataset['test']
    testset_with_labels = []
    for i in range(len(test)):
        if test[i]["is_solution"]:
            testset_with_labels.append(test[i])
    testset_with_labels = testset_with_labels[:args.numeval]

    lora_config = LoraConfig(
        peft_type=TaskType.CAUSAL_LM,
        r=args.lora_rank,
        lora_alpha=16
    )
    foundational_model = AutoModelForCausalLM.from_pretrained(
        args.base_name,
        torch_dtype=torch.bfloat16,
        use_cache=True,
        _attn_implementation="flash_attention_2"
    ).cpu()
    # foundational_model.load_adapter("/cm/shared/huypn16/ReinforceLLM-base_or_pgs_meta/ReinforceLLM-base_or_pgs/checkpoints/mistral-7b-reinforce-0-25/")
    tokenizer = AutoTokenizer.from_pretrained(args.base_name)
    tokenizer.pad_token = tokenizer.eos_token
    foundational_model.enable_input_require_grads()

    lora_config = LoraConfig(
        peft_type=TaskType.CAUSAL_LM,
        r=args.lora_rank,
    )
    foundational_model = get_peft_model(foundational_model, lora_config)
    # Only the LoRA adapters and the LM head remain trainable.
    for name, param in foundational_model.named_parameters():
        if "lm_head" in name:
            param.requires_grad = True
        elif "lora" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Keep a frozen copy of the base policy so it can be restored after each search step.
    old_model = deepcopy(foundational_model)
    foundational_model = torch.nn.DataParallel(foundational_model)
    foundational_model.to('cuda:0')

    # Math-Shepherd PRM conventions: '+'/'-' are the good/bad step labels and 'ки' marks step boundaries.
    good_token = '+'
    bad_token = '-'
    step_tag = 'ки'
    rm_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name)
    candidate_tokens = rm_tokenizer.encode(f"{good_token} {bad_token}")[1:]  # [648, 387]
    step_tag_id = rm_tokenizer.encode(f"{step_tag}")[-1]
    rm_model = AutoModelForCausalLM.from_pretrained(
        args.reward_model_name,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2"
    ).eval()
    rm_model.to('cuda:1')

    # Step-tag id under the policy tokenizer (it may differ from the reward model's id).
    step_tag_id_policy = tokenizer.encode(f"{step_tag}")[-1]

    reinforce_config = REINFORCEConfig(batch_size=args.k, mini_batch_size=6, step_tag_id=step_tag_id_policy, reinforce_epochs=4)
    reinforce_trainer = REINFORCETrainer(reinforce_config, foundational_model, tokenizer=tokenizer)

    accuracy = 0.0
    cnt = 0
    for example in tqdm(testset_with_labels):
        cnt += 1
        # The first 5 examples are skipped; the accuracy denominators below account for this.
        if cnt < 6:
            continue
        problem_statement = example["instruction"]
        print("--------------------------We are solving this problem: ------------------")
        print(problem_statement)
        reach_terminal = False
        query_tensors = tokenizer.encode(problem_statement, return_tensors="pt").to('cuda:0')
        current_sol_prefix_tensor = query_tensors
        current_step = 0
        all_proposed_solutions = []
        all_proposed_rewards = []
        temp = 0.7
        rp = 0.5
        while not reach_terminal:
            # try:
            print("searching for next step")
            if current_step >= 30:
                break
            # response_tensors, skip, current_step_response_tensors = respond_to_batch(old_model, step_tag_id_policy, tokenizer.eos_token_id, current_sol_prefix_tensor.repeat(args.k, 1), txt_len=1100, top_p=1.0, sampling_steps=args.num_truncated_rollouts, current_step=current_step, tokenizer=tokenizer)
            # Sample k rollouts from the current policy, continuing the solution prefix.
            with torch.no_grad():
                response_tensors = reinforce_trainer.generate(
                    [current_sol_prefix_tensor[0] for i in range(args.k)],
                    return_prompt=False,
                    max_new_tokens=1200 - len(current_sol_prefix_tensor[0]),
                    batch_size=6,
                    do_sample=True,
                    temperature=temp
                )
            proposed_solutions = [
                tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True)
                for response_tensor in response_tensors
            ]
            # Flag rollouts that degenerate into heavy 3-gram repetition; no filtering is applied to them here.
            for idx, solution in enumerate(proposed_solutions):
                rep_cnt = count_repeated_ngrams(solution, tokenizer, 3)
                if rep_cnt >= 10:
                    pass
            all_proposed_solutions.extend(proposed_solutions)
            # Truncate each rollout at the predictive_horizon-th step tag to obtain the candidate next step(s).
            current_step_indices = []
            for sample in response_tensors:
                if len((sample == step_tag_id).nonzero(as_tuple=True)[0]) >= args.predictive_horizon:
                    current_step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][args.predictive_horizon - 1]
                else:
                    current_step_tag_index = 100000
                if tokenizer.eos_token_id in sample:
                    stop_index = len((sample == tokenizer.eos_token_id).nonzero(as_tuple=True)[0])
                    if (stop_index - current_step_tag_index) == 1:
                        current_step_tag_index = stop_index
                current_step_indices.append(current_step_tag_index)
            current_step_response_tensors = [response_tensors[i][:current_index + 1] for i, current_index in zip(range(args.k), current_step_indices)]
            # Score the full rollouts with the process reward model.
            rm_model = rm_model.to('cuda:1')
            rewards = get_rewards(rm_model, rm_tokenizer, problem_statement, proposed_solutions, candidate_tokens, step_tag_id)
            rm_model = rm_model.to('cpu')
            # list of k tensors, each of size (truncated_number_of_actions,)
            rewards = [reward.to('cuda:1') for reward in rewards]
            all_proposed_rewards.extend(rewards)
            # Choose the base policy's next step by cumulative lookahead reward.
            step_rewards = [
                (reward[current_step:].mean() + reward[current_step]) if (len(reward) > current_step) else reward.mean()
                for reward in rewards
            ]
            most_probable_base_next_step_tensor = current_step_response_tensors[step_rewards.index(max(step_rewards))]
            print("Most probable base step: ")
            print(tokenizer.decode(most_probable_base_next_step_tensor))
            print(step_rewards)
            # Replace NaN-scored rollouts with the best-scoring one before the policy-gradient update.
            id_max = step_rewards.index(max(step_rewards))
            for idx, reward in enumerate(step_rewards):
                if torch.isnan(reward):
                    response_tensors[idx] = response_tensors[id_max]
                    step_rewards[idx] = step_rewards[id_max]
            # Policy-gradient (REINFORCE) update on the sampled rollouts.
            try:
                if not any([torch.isnan(reward).item() for reward in step_rewards]):
                    print([torch.isnan(reward) for reward in step_rewards])
                    print(any([torch.isnan(reward).item() for reward in step_rewards]))
                    stats = reinforce_trainer.step(
                        [current_sol_prefix_tensor[0] for i in range(args.k)],
                        [response_tensors[i] for i in range(args.k)],
                        rewards,
                        logged=False
                    )
                    print(stats)
            except Exception:
                # If the update fails, skip the search rollouts for this step.
                continue
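            # Policy-gradient search: the REINFORCE update above nudges the policy toward
            # higher-PRM rollouts; the block below re-samples k rollouts from the adapted
            # policy and later compares its best candidate step against the base policy's
            # best candidate step, keeping the base step only if it wins by more than epsilon.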
            # pg evaluation: sample rollouts from the PGS-adapted policy and score them.
            with torch.no_grad():
                search_response_tensors = reinforce_trainer.generate(
                    [current_sol_prefix_tensor[0] for i in range(args.k)],
                    return_prompt=False,
                    max_new_tokens=1200 - len(current_sol_prefix_tensor[0]),
                    batch_size=6,
                    do_sample=True,
                    temperature=temp
                )
            search_proposed_solutions = [
                tokenizer.decode(torch.cat([current_sol_prefix_tensor[0], response_tensor], dim=-1), skip_special_tokens=True)
                for response_tensor in search_response_tensors
            ]
            all_proposed_solutions.extend(search_proposed_solutions)
            current_step_indices = []
            for sample in search_response_tensors:
                if len((sample == step_tag_id).nonzero(as_tuple=True)[0]) >= args.predictive_horizon:
                    current_step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][args.predictive_horizon - 1]
                else:
                    current_step_tag_index = 100000
                if tokenizer.eos_token_id in sample:
                    stop_index = len((sample == tokenizer.eos_token_id).nonzero(as_tuple=True)[0])
                    if (stop_index - current_step_tag_index) == 1:
                        current_step_tag_index = stop_index
                current_step_indices.append(current_step_tag_index)
            current_search_step_response_tensors = [search_response_tensors[i][:current_index + 1] for i, current_index in zip(range(args.k), current_step_indices)]
            # get rewards for the search rollouts
            rm_model = rm_model.to('cuda:1')
            search_rewards = get_rewards(rm_model, rm_tokenizer, problem_statement, search_proposed_solutions, candidate_tokens, step_tag_id)
            rm_model = rm_model.to("cpu")
            search_rewards = [reward.to('cuda:1') for reward in search_rewards]
            all_proposed_rewards.extend(search_rewards)
            # Choose the PGS-adapted policy's next step by cumulative lookahead reward.
            search_step_rewards = [
                (reward[current_step:].mean() + reward[current_step]) if (len(reward) > current_step) else reward.mean()
                for reward in search_rewards
            ]
            most_probable_search_next_step_tensor = current_search_step_response_tensors[search_step_rewards.index(max(search_step_rewards))]
            print("Most probable search step: ")
            print(tokenizer.decode(most_probable_search_next_step_tensor))
            print(search_step_rewards)
            # Pick the next step from either the base policy or the PGS-adapted policy.
            most_probable_next_step_tensor = most_probable_base_next_step_tensor if max(step_rewards) > max(search_step_rewards) + args.epsilon else most_probable_search_next_step_tensor
            print("We chose the PGS-adapted policy" if max(step_rewards) <= max(search_step_rewards) + args.epsilon else "We chose the base policy")
            print("Reward difference: ", max(step_rewards) - max(search_step_rewards))
            # update the current solution prefix tensor
            current_sol_prefix_tensor = torch.cat([current_sol_prefix_tensor, most_probable_next_step_tensor.unsqueeze(0)], dim=-1)
            if tokenizer.eos_token_id in current_sol_prefix_tensor:
                reach_terminal = True
            # reset the policy after each search step so it is not biased by the previous search
            # if max(step_rewards) > max(search_step_rewards) + args.epsilon:
            foundational_model = deepcopy(old_model)
            foundational_model = torch.nn.DataParallel(foundational_model)
            foundational_model.to('cuda:0')
            current_step += args.predictive_horizon
            # Gradually increase the sampling temperature to diversify later rollouts.
            temp += 0.15
            del response_tensors, proposed_solutions, rewards

        # run evaluation for this problem
        solution = tokenizer.decode(current_sol_prefix_tensor[0], skip_special_tokens=True)
        main_solution = solution
        rm_model = rm_model.to('cuda:1')
        final_reward = get_rewards(rm_model, rm_tokenizer, problem_statement, [solution], candidate_tokens, step_tag_id)
        rm_model = rm_model.to('cpu')
        all_proposed_solutions.append(solution)
        all_proposed_rewards.append(final_reward[0])
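        # The ranking below condenses each solution's reward trace to the mean of its last
        # (up to) five step rewards; e.g. a trace of [0.9, 0.8, 0.7, 0.6, 0.5, 0.4] is
        # scored as mean([0.8, 0.7, 0.6, 0.5, 0.4]) = 0.6.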
        all_proposed_rewards = [sum(reward[-5:]) / len(reward[-5:]) if len(reward[-5:]) > 0 else 1 for reward in all_proposed_rewards]
        print(f"We have sampled {len(all_proposed_solutions)} solutions in total")
        # Grade the best-ranked solution against the reference answer.
        solution = all_proposed_solutions[all_proposed_rewards.index(max(all_proposed_rewards))]
        is_passed, main_prediction = math_evaluation(solution, example["answer"])
        accuracy += is_passed
        print("Final solution: ", solution)
        print("Final main solution: ", main_solution)
        print("Final answer: ", main_prediction)
        print("Correct answer:", example["answer"])
        print("Problem done with correct solution: ", is_passed)
        # reset the policy before the next problem
        foundational_model = deepcopy(old_model)
        foundational_model = torch.nn.DataParallel(foundational_model)
        foundational_model.to('cuda:0')
        # Running accuracy over the examples evaluated so far (the first 5 are skipped).
        print(f"Accuracy: {accuracy/(cnt-5)}")
        # Accuracy over the full evaluation set, minus the 5 skipped examples.
        print(f"Accuracy: {accuracy/(len(testset_with_labels) - 5)}")


if __name__ == '__main__':
    args = parse_args()
    main(args)
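
# Example invocation (illustrative; the script filename is a placeholder and the model ids
# are the defaults already set in parse_args above):
#   python pgs_eval.py \
#       --base_name peiyi9979/mistral-7b-sft \
#       --reward_model_name peiyi9979/math-shepherd-mistral-7b-prm \
#       --k 6 --numeval 55 --predictive_horizon 2 --num-truncated-rollouts 4 --verbose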