Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- .gitignore +10 -0
- .vscode/settings.json +5 -0
- README.md +110 -0
- datasets/hh-rlhf/.gitattributes +54 -0
- datasets/hh-rlhf/README.md +68 -0
- datasets/hh-rlhf/harmless-base/test.jsonl.gz +3 -0
- datasets/hh-rlhf/harmless-base/train.jsonl.gz +3 -0
- datasets/hh-rlhf/helpful-base/test.jsonl.gz +3 -0
- datasets/hh-rlhf/helpful-base/train.jsonl.gz +3 -0
- datasets/hh-rlhf/helpful-online/test.jsonl.gz +3 -0
- datasets/hh-rlhf/helpful-online/train.jsonl.gz +3 -0
- datasets/hh-rlhf/helpful-rejection-sampled/test.jsonl.gz +3 -0
- datasets/hh-rlhf/helpful-rejection-sampled/train.jsonl.gz +3 -0
- datasets/hh-rlhf/red-team-attempts/red_team_attempts.jsonl.gz +3 -0
- export_hf_checkpoint.py +66 -0
- fix_trl/utils.py +739 -0
- plot.py +103 -0
- ppo/__pycache__/multi_reward_models.cpython-310.pyc +0 -0
- ppo/__pycache__/utils.cpython-310.pyc +0 -0
- ppo/eval_ppo_single_model.py +177 -0
- ppo/eval_rewarded_soups.py +220 -0
- ppo/morlhf-llama3.py +293 -0
- ppo/morlhf.py +277 -0
- ppo/multi_reward_models.py +64 -0
- ppo/ppo.py +313 -0
- ppo/ppo_reft.py +313 -0
- ppo/ppo_reft_fine_grained.py +336 -0
- ppo/scripts/load.sh +31 -0
- ppo/scripts/morlhf_llama3-inst.sh +15 -0
- ppo/scripts/morlhf_llama3-preference.sh +37 -0
- ppo/scripts/morlhf_llama3-preference1.sh +32 -0
- ppo/scripts/morlhf_llama3-preference2.sh +32 -0
- ppo/scripts/morlhf_llama3.sh +17 -0
- ppo/scripts/ppo_reft-fine_grained.sh +17 -0
- ppo/scripts/ppo_reft-fine_grained2.sh +18 -0
- ppo/scripts/ppo_reft.sh +18 -0
- ppo/scripts/ppo_single_vector.sh +17 -0
- ppo/scripts/run.sh +2 -0
- ppo/scripts/test.sh +17 -0
- ppo/test.ipynb +248 -0
- ppo/test_batch_decode.ipynb +294 -0
- ppo/utils.py +411 -0
- push_model.sh +8 -0
- push_with_hf.sh +1 -0
- requirements.txt +22 -0
- ric/evaluation.py +234 -0
- ric/generation.py +156 -0
- ric/main.py +155 -0
- ric/multi_reward_models.py +64 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ric_txt2img/docs/source/en/imgs/access_request.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ric_txt2img/results/*
|
| 2 |
+
ric_txt2img/pretrained/*
|
| 3 |
+
images/*
|
| 4 |
+
ric_txt2img/requirements.txt
|
| 5 |
+
ric_txt2img/eval_jpeg_aes_tmp.sh
|
| 6 |
+
download_data.py
|
| 7 |
+
check_data.py
|
| 8 |
+
process_data.py
|
| 9 |
+
ric_txt2img/data/*
|
| 10 |
+
ric_txt2img/experiments/
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
| 3 |
+
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
| 4 |
+
"python-envs.pythonProjects": []
|
| 5 |
+
}
|
README.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RiC: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment
|
| 2 |
+
|
| 3 |
+
Code for the ICML 2024 paper "Rewards-in-Context: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment". This repo is based on [trl](https://github.com/huggingface/trl).
|
| 4 |
+
|
| 5 |
+
This readme file is for LLM tasks, please refer to the 'ric_txt2img' directory for the code for conditional diffusion model training.
|
| 6 |
+
|
| 7 |
+
**Note:** We have updated the code to include the EOS token in data preparation for SFT and RiC, which helps resolve the issue of non-ending generation. As a result, the model can produce more natural responses, although the results may differ from the paper. If you wish to achieve the exact performance in the paper, please use an old version prior to November 3, 2024.
|
| 8 |
+
|
| 9 |
+
## 1. Installation
|
| 10 |
+
Install the requirements.
|
| 11 |
+
```
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
The following issue has been fixed using token_ids directly for response_template as in https://huggingface.co/docs/trl/main/en/sft_trainer#train-on-completions-only.
|
| 16 |
+
```
|
| 17 |
+
Fix the trl package.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
The trl package currently has an issue with its 'DataCollatorForCompletionOnlyLM' implementation, wherein the first token of the template may be nonsensical and unmatched. Hence, it is necessary to make modifications to the trl/trainer/utils.py file within your trl package (usually under the path {path to conda or python}/lib/python3/site-packages), specifically the 'torch_call' function starting from line 119. Upon installing the trl package based on the requirements.txt, you can easily replace the trl/trainer/utils.py with the utils.py file located in the 'fix_trl' directory.
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## 2. Usage
|
| 24 |
+
### 2.1 RiC
|
| 25 |
+
```
|
| 26 |
+
Note: To reproduce the results in the paper, ensure the same training steps, for example, (4 processes, batchsize=2, steps=20000), and load the model in 8bit (--load_in_8bit True). We have observed notable performance differences between 8-bit and 16-bit models, particularly in the summary task.
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
* **Preparing datasets**: the first step is to generate dataset for RiC training.
|
| 31 |
+
```
|
| 32 |
+
cd ./ric
|
| 33 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch prepare_dataset_with_rewards.py --reward_names 'harmless,helpful' --exp_type 'assistant' --save_directory './datasets/train_harmhelp.hf'
|
| 34 |
+
```
|
| 35 |
+
```
|
| 36 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch prepare_dataset_with_rewards.py --reward_names 'summary,faithful' --exp_type 'summary' --save_directory './datasets/summary_pref1faithful.hf'
|
| 37 |
+
```
|
| 38 |
+
Here, 'exp_type' includes two types of tasks, i.e., 'assistant' and 'summary'. Moreover, 'reward_names' can be sampled from \{'harmless', 'helpful', 'humor', 'summary', 'deberta', 'faithful'\}, where the first three are used for the 'assistant' task while the last three are used for the 'summary' task.
|
| 39 |
+
|
| 40 |
+
* **Training**: you can train RiC through the main.py file.
|
| 41 |
+
|
| 42 |
+
Offline training:
|
| 43 |
+
```
|
| 44 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch main.py --train_dataset_path {path_to_corresponding_dataset} --exp_type 'assistant' --reward_names 'harmless,helpful' --training_steps 20000 --num_online_iterations 0 --wandb_name 'ric_assistant_harmlesshelpful_offline20000' --batch_size 2 --load_in_8bit True
|
| 45 |
+
```
|
| 46 |
+
Offline training + online training:
|
| 47 |
+
```
|
| 48 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch main.py --train_dataset_path {path_to_corresponding_dataset} --exp_type 'assistant' --reward_names 'harmless,helpful' --training_steps 20000 --online_training_steps 4000 --num_online_iterations 2 --wandb_name 'ric_assistant_harmlesshelpful_offline20000_onlineiter2' --batch_size 2 --load_in_8bit True
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
* **Evaluation**: After training, models are saved into 'save_directory' path. You can evaluate models using the evaluation.py.
|
| 52 |
+
```
|
| 53 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch evaluation.py --peft_name {path_to_the_saved_peft} --reward_names 'harmless,helpful' --exp_type 'assistant' --wandb_name {path_of_the_experiment}
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
### 2.2 SFT
|
| 57 |
+
* Training a SFT model:
|
| 58 |
+
```
|
| 59 |
+
cd ./sft
|
| 60 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch sft.py --base_model_name 'meta-llama/Llama-2-7b-hf' --exp_type 'summary' --wandb_name {name_of_the_experiment}
|
| 61 |
+
```
|
| 62 |
+
* Evaluating a SFT model:
|
| 63 |
+
```
|
| 64 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch eval_multiobjective.py --base_model_name {path_to_the_saved_SFT_model} --exp_type 'summary' --wandb_name {name_of_the_experiment} --reward_names 'summary,faithful'
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
* Merging the SFT model if using lora for finetuning:
|
| 68 |
+
```
|
| 69 |
+
python3 merge_sft_lora.py --base_model_name 'meta-llama/Llama-2-7b-hf' --lora_name {path_to_the_lora_file} --save_directory {path_to_the_saving_directory}
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 2.3 MORLHF
|
| 73 |
+
MORLHF optimizes preference weighted rewards using PPO, based on the SFT model in 2.2.
|
| 74 |
+
|
| 75 |
+
Train MORLHF:
|
| 76 |
+
```
|
| 77 |
+
cd ./ppo
|
| 78 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch morlhf.py --preference 0.5 --base_model_name {path_to_the_sft_model} --reward_names 'harmless,helpful' --exp_type 'assistant' --wandb_name 'morlhf_llamma2'
|
| 79 |
+
```
|
| 80 |
+
Here, the 'preference' is the preference for the first reward.
|
| 81 |
+
|
| 82 |
+
Evaluate MORLHF model:
|
| 83 |
+
```
|
| 84 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch eval_ppo_single_model.py --base_model_name {path_to_the_saved_morlhf_model} --reward_names 'harmless,helpful' --exp_type 'assistant' --wandb_name 'eval_morlhf_llamma2'
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### 2.4 Rewarded Soups
|
| 88 |
+
Rewarded Soups train $N$ PPO models for $N$ reward models.
|
| 89 |
+
|
| 90 |
+
Train PPO model for each reward model:
|
| 91 |
+
```
|
| 92 |
+
cd ./ppo
|
| 93 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch ppo.py --reward_name 'harmless' --base_model_name {path_to_the_SFT_model} --exp_type 'assistant' --wandb_name {name_for_this_experiment}
|
| 94 |
+
```
|
| 95 |
+
Evaluate Rewarded Soups with multiple PPO models:
|
| 96 |
+
```
|
| 97 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch eval_rewarded_soups.py --reward_names 'harmless,helpful' --base_model_path1 {path_to_the_PPO_model_for_harmless} --base_model_path2 {path_to_the_PPO_model_for_helpful} --exp_type 'assistant' --wandb_name 'eval_pposoups_harmless_helpful'
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
## 3. Citation
|
| 102 |
+
If you find our work useful for your research, please cite:
|
| 103 |
+
```
|
| 104 |
+
@article{yang2024rewards,
|
| 105 |
+
title={Rewards-in-Context: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment},
|
| 106 |
+
author={Yang, Rui and Pan, Xiaoman and Luo, Feng and Qiu, Shuang and Zhong, Han and Yu, Dong and Chen, Jianshu},
|
| 107 |
+
journal={International Conference on Machine Learning},
|
| 108 |
+
year={2024}
|
| 109 |
+
}
|
| 110 |
+
```
|
datasets/hh-rlhf/.gitattributes
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
# Audio files - uncompressed
|
| 37 |
+
*.pcm filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.sam filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.raw filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
# Audio files - compressed
|
| 41 |
+
*.aac filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
# Image files - uncompressed
|
| 47 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
*.tiff filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
# Image files - compressed
|
| 52 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
datasets/hh-rlhf/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- human-feedback
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# Dataset Card for HH-RLHF
|
| 8 |
+
|
| 9 |
+
## Dataset Summary
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
This repository provides access to two different kinds of data:
|
| 13 |
+
1. Human preference data about helpfulness and harmlessness from [Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback](https://arxiv.org/abs/2204.05862). These data are meant to train preference (or reward) models for subsequent RLHF training. These data are *not* meant for supervised training of dialogue agents. Training dialogue agents on these data is likely to lead to harmful models and this shold be avoided.
|
| 14 |
+
2. Human-generated and annotated red teaming dialogues from [Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned](https://www.anthropic.com/red_teaming.pdf). These data are meant to understand how crowdworkers red team models and what types of red team attacks are succesful or not. The data are *not* meant for fine-tuning or preference modeling (use the data above for preference modeling). These data are entire transcripts of conversations that are derived from the harmlessness preference modeling data described above, where only the chosen response is incorporated into the overall transcript. Furthermore, the transcripts are annotated with human and automated measurements of how harmful the overall dialogues are.
|
| 15 |
+
|
| 16 |
+
**Disclaimer**: The data (especially the harmlessness preference data and the red team data) contain content that may be offensive or upsetting. Topics include, but are not limited to, discriminatory language and discussions of abuse, violence, self-harm, exploitation, and other potentially upsetting subject matter. Please only engage with the data in accordance with your own personal risk tolerance. The data are intended for research purposes, especially research that can make models *less* harmful. The views expressed in the data do not reflect the views of Anthropic or any of its employees. As mentioned above, these data are *not* intended for training dialogue agents as this will likely lead to harmful model behavior.
|
| 17 |
+
|
| 18 |
+
Each of these datasets are described further below.
|
| 19 |
+
|
| 20 |
+
## Human preference data about helpfulness and harmlessness (PM Data)
|
| 21 |
+
|
| 22 |
+
The data are described in the paper: [Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback](https://arxiv.org/abs/2204.05862). If you find the data useful, please cite the paper. The data format is very simple -- each line of the jsonl files contains a pair of texts, one "chosen" and one "rejected".
|
| 23 |
+
|
| 24 |
+
For **helpfulness**, the data are grouped into train/test splits in three tranches: from our base models (context-distilled 52B language models), via rejection sampling (mostly with best-of-16 sampling) against an early preference model, and a dataset sampled during our iterated "online" process.
|
| 25 |
+
|
| 26 |
+
For **harmlessness**, the data are only collected for our base models, but otherwise formatted in the same way.
|
| 27 |
+
|
| 28 |
+
Details about the data collection process and crowdworker population can be found in the paper, specifically in section 2 and appendix D.
|
| 29 |
+
|
| 30 |
+
## Red teaming data (not PM Data)
|
| 31 |
+
|
| 32 |
+
The data are described in the paper: [Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned](https://www.anthropic.com/red_teaming.pdf). If you find the data useful, please cite the paper. At a high level these data can be examined to understand what succesful red team attempts look like.
|
| 33 |
+
|
| 34 |
+
Details about the data and data collection procedures can be found in the Datasheet in the appendix of the paper.
|
| 35 |
+
|
| 36 |
+
Each line of the jsonl file contains a dictionary with the following fields:
|
| 37 |
+
- `transcript` a text transcript of a conversation between a human adversary (red team member) and an AI assistant
|
| 38 |
+
- `min_harmlessness_score_transcript` a real value score of the harmlessness of the AI assistant (lower is more harmful) as obtained from a preference model
|
| 39 |
+
- `num_params` number of parameters in the language model powering the AI assistant
|
| 40 |
+
- `model_type` type of model powering the AI assistant
|
| 41 |
+
- `rating` the red team member's rating of how successful they were at breaking the AI assistant (Likert scale, higher is more successful)
|
| 42 |
+
- `task_description` a short text description written by the red team member about how they tried to red team the AI assistant
|
| 43 |
+
- `task_description_harmlessness_score` a real value score of the harmlessness of the task description (lower is more harmful) as obtained from a preference model
|
| 44 |
+
- `red_team_member_id` an arbitrary identifier of the red team member. one red team member can generate multiple red team attacks
|
| 45 |
+
- `is_upworker` a binary indicator that is true if the red team member was from the crowd platform Upwork or false if they were from MTurk
|
| 46 |
+
- `tags` a list of up to 6 tags per transcript. tags are short descriptions of the red team attempts generated by crowdworkers who reviewed red team data post-hoc. tags were only provided for a random sample of 1000 red team attempts for two of four model types.
|
| 47 |
+
|
| 48 |
+
## Usage
|
| 49 |
+
|
| 50 |
+
Each of the above datasets is located in a separate sub-directory. To load an individual subset, use the `data_dir` argument of the `load_dataset()` function as follows:
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
from datasets import load_dataset
|
| 54 |
+
|
| 55 |
+
# Load all helpfulness/harmless subsets (share the same schema)
|
| 56 |
+
dataset = load_dataset("Anthropic/hh-rlhf")
|
| 57 |
+
|
| 58 |
+
# Load one of the harmless subsets
|
| 59 |
+
dataset = load_dataset("Anthropic/hh-rlhf", data_dir="harmless-base")
|
| 60 |
+
|
| 61 |
+
# Load the red teaming subset
|
| 62 |
+
dataset = load_dataset("Anthropic/hh-rlhf", data_dir="red-team-attempts")
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Contact
|
| 66 |
+
|
| 67 |
+
The original authors host this dataset on GitHub here: https://github.com/anthropics/hh-rlhf
|
| 68 |
+
You can submit inquiries to: redteam@anthropic.com
|
datasets/hh-rlhf/harmless-base/test.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ebfaaed21162a4de120ae85366075b34425e2a6303fac59b99555002f7016f03
|
| 3 |
+
size 742801
|
datasets/hh-rlhf/harmless-base/train.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf3e1ecd3db61bf21d190f69351d26670c67aababa0fd282fa008c17a7202217
|
| 3 |
+
size 13197306
|
datasets/hh-rlhf/helpful-base/test.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8be3fc1a13b27901631696f2be6f184c799f1baeaf145f53ac5db24960adc37b
|
| 3 |
+
size 875191
|
datasets/hh-rlhf/helpful-base/train.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:518a5bf288456fc9f3b7c980c54116fba0c52f274d3f4d344675d83e4058f6f4
|
| 3 |
+
size 16200131
|
datasets/hh-rlhf/helpful-online/test.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39a475cc7e0df972902f8e3a0972726fce826ac65369bda8c671de4330c60a10
|
| 3 |
+
size 1053578
|
datasets/hh-rlhf/helpful-online/train.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0994d18b82083f5c084ce04071b683958a724df8c7215ab1d851986642d36a2a
|
| 3 |
+
size 20127228
|
datasets/hh-rlhf/helpful-rejection-sampled/test.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:abc35dbc56015db1dae0faa484514ccbc3eee0133bfcd557fc69250619b028bb
|
| 3 |
+
size 1361231
|
datasets/hh-rlhf/helpful-rejection-sampled/train.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cea2d2ab2187624ec1314e8cc6fb579f0fa987b3a33e8102b7556ed4d255f64
|
| 3 |
+
size 25697148
|
datasets/hh-rlhf/red-team-attempts/red_team_attempts.jsonl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c7b0069991460f0064f279fd400b51f3f0095697d14d7793c49b0925f80814f
|
| 3 |
+
size 15483307
|
export_hf_checkpoint.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import transformers
|
| 5 |
+
from peft import PeftModel
|
| 6 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer # noqa: F402
|
| 7 |
+
|
| 8 |
+
# BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
| 9 |
+
# assert (
|
| 10 |
+
# BASE_MODEL
|
| 11 |
+
# ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
BASE_MODEL = '/home/hector5/models/Llama-3-Base-8B-SFT/'
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# LORA_WEIGHT = '/home/hector5/RiC/ppo/ckpt_morlhf/morlhf_llamma3_lora-0.33_0.33_0.33/epoch_1/'
|
| 18 |
+
# LORA_WEIGHT = '/home/hector5/RiC/ppo/ckpt_morlhf/morlhf_llamma3_lora-0.2_0.2_0.6/epoch_1/'
|
| 19 |
+
LORA_WEIGHT = '/home/hector5/RiC/ppo/ckpt_morlhf/morlhf_llamma3_lora-0.0_0.0_1.0/epoch_0/'
|
| 20 |
+
|
| 21 |
+
OUTPUT_path = './merged_models/Llama-3-Base-8B-SFT-morlhf-humor'
|
| 22 |
+
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 24 |
+
tokenizer.save_pretrained(OUTPUT_path)
|
| 25 |
+
|
| 26 |
+
base_model = LlamaForCausalLM.from_pretrained(
|
| 27 |
+
BASE_MODEL,
|
| 28 |
+
load_in_8bit=False,
|
| 29 |
+
torch_dtype=torch.bfloat16,
|
| 30 |
+
device_map={"": "cpu"},
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
|
| 34 |
+
first_weight_old = first_weight.clone()
|
| 35 |
+
|
| 36 |
+
lora_model = PeftModel.from_pretrained(
|
| 37 |
+
base_model,
|
| 38 |
+
LORA_WEIGHT,
|
| 39 |
+
device_map={"": "cpu"},
|
| 40 |
+
torch_dtype=torch.bfloat16,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
lora_weight = lora_model.base_model.model.model.layers[
|
| 44 |
+
0
|
| 45 |
+
].self_attn.q_proj.weight
|
| 46 |
+
|
| 47 |
+
assert torch.allclose(first_weight_old, first_weight)
|
| 48 |
+
|
| 49 |
+
# merge weights - new merging method from peft
|
| 50 |
+
lora_model = lora_model.merge_and_unload()
|
| 51 |
+
|
| 52 |
+
lora_model.train(False)
|
| 53 |
+
|
| 54 |
+
# did we do anything?
|
| 55 |
+
assert not torch.allclose(first_weight_old, first_weight)
|
| 56 |
+
|
| 57 |
+
lora_model_sd = lora_model.state_dict()
|
| 58 |
+
deloreanized_sd = {
|
| 59 |
+
k.replace("base_model.model.", ""): v
|
| 60 |
+
for k, v in lora_model_sd.items()
|
| 61 |
+
if "lora" not in k
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
LlamaForCausalLM.save_pretrained(
|
| 65 |
+
base_model, OUTPUT_path, state_dict=deloreanized_sd, max_shard_size="2048MB"
|
| 66 |
+
)
|
fix_trl/utils.py
ADDED
|
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import random
|
| 15 |
+
import warnings
|
| 16 |
+
from collections import deque
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from accelerate import PartialState
|
| 23 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 24 |
+
from torch.utils.data import IterableDataset
|
| 25 |
+
from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
|
| 26 |
+
|
| 27 |
+
from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
|
| 28 |
+
from ..trainer.model_config import ModelConfig
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if is_peft_available():
|
| 32 |
+
from peft import LoraConfig, PeftConfig
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class AdaptiveKLController:
|
| 36 |
+
"""
|
| 37 |
+
Adaptive KL controller described in the paper:
|
| 38 |
+
https://arxiv.org/pdf/1909.08593.pdf
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, init_kl_coef, target, horizon):
|
| 42 |
+
self.value = init_kl_coef
|
| 43 |
+
self.target = target
|
| 44 |
+
self.horizon = horizon
|
| 45 |
+
|
| 46 |
+
def update(self, current, n_steps):
|
| 47 |
+
target = self.target
|
| 48 |
+
proportional_error = np.clip(current / target - 1, -0.2, 0.2)
|
| 49 |
+
mult = 1 + proportional_error * n_steps / self.horizon
|
| 50 |
+
self.value *= mult
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FixedKLController:
|
| 54 |
+
"""Fixed KL controller."""
|
| 55 |
+
|
| 56 |
+
def __init__(self, kl_coef):
|
| 57 |
+
self.value = kl_coef
|
| 58 |
+
|
| 59 |
+
def update(self, current, n_steps):
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
|
| 64 |
+
"""
|
| 65 |
+
Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
|
| 66 |
+
when they do not come from the assistant. This ensure that the loss is only
|
| 67 |
+
calculated on the completion made by the assistant.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
|
| 71 |
+
'### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
|
| 72 |
+
differently if it does not have proper context.
|
| 73 |
+
instruction_template (`Union[str, List[int]]`): the template form that indicates the start of the human instruction, typically something like
|
| 74 |
+
'### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
|
| 75 |
+
mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
|
| 76 |
+
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
|
| 77 |
+
for flexibility and backwards-compatibility.
|
| 78 |
+
ignore_index (`int`, *optional*, defaults to `-100`):
|
| 79 |
+
The index to use to ignore the initial tokens with
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
response_template: Union[str, List[int]],
|
| 85 |
+
instruction_template: Optional[Union[str, List[int]]] = None,
|
| 86 |
+
*args,
|
| 87 |
+
mlm: bool = False,
|
| 88 |
+
ignore_index: int = -100,
|
| 89 |
+
**kwargs,
|
| 90 |
+
):
|
| 91 |
+
super().__init__(*args, mlm=mlm, **kwargs)
|
| 92 |
+
|
| 93 |
+
self.instruction_template = instruction_template
|
| 94 |
+
if isinstance(instruction_template, str):
|
| 95 |
+
# The user provides a string, must tokenize
|
| 96 |
+
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
|
| 97 |
+
else:
|
| 98 |
+
# The user already provides the token ids
|
| 99 |
+
self.instruction_token_ids = instruction_template
|
| 100 |
+
|
| 101 |
+
self.response_template = response_template
|
| 102 |
+
if isinstance(response_template, str):
|
| 103 |
+
# The user provides a string, must tokenize
|
| 104 |
+
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
|
| 105 |
+
else:
|
| 106 |
+
# The user already provides the token ids
|
| 107 |
+
self.response_token_ids = response_template
|
| 108 |
+
|
| 109 |
+
if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
| 110 |
+
warnings.warn(
|
| 111 |
+
"The pad_token_id and eos_token_id values of this tokenizer are identical. "
|
| 112 |
+
"If you are planning for multi-turn training, "
|
| 113 |
+
"it can result in the model continuously generating questions and answers without eos token. "
|
| 114 |
+
"To avoid this, set the pad_token_id to a different value."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.ignore_index = ignore_index
|
| 118 |
+
|
| 119 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
| 120 |
+
batch = super().torch_call(examples)
|
| 121 |
+
|
| 122 |
+
if self.instruction_template is None:
|
| 123 |
+
for i in range(len(examples)):
|
| 124 |
+
response_token_ids_start_idx = None
|
| 125 |
+
|
| 126 |
+
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
|
| 127 |
+
# `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
|
| 128 |
+
if (
|
| 129 |
+
self.response_token_ids
|
| 130 |
+
== batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
|
| 131 |
+
):
|
| 132 |
+
response_token_ids_start_idx = idx
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if response_token_ids_start_idx is None:
|
| 136 |
+
for idx in np.where(batch["labels"][i] == self.response_token_ids[1])[0]:
|
| 137 |
+
if (
|
| 138 |
+
self.response_token_ids[1:]
|
| 139 |
+
== batch["labels"][i][idx : idx + len(self.response_token_ids[1:])].tolist()
|
| 140 |
+
):
|
| 141 |
+
response_token_ids_start_idx = idx
|
| 142 |
+
|
| 143 |
+
if response_token_ids_start_idx is None:
|
| 144 |
+
warnings.warn(
|
| 145 |
+
f"Could not find response key `{self.response_template}` in the "
|
| 146 |
+
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
|
| 147 |
+
f"This instance will be ignored in loss calculation. "
|
| 148 |
+
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
| 149 |
+
)
|
| 150 |
+
batch["labels"][i, :] = self.ignore_index
|
| 151 |
+
else:
|
| 152 |
+
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids[1:])
|
| 153 |
+
else:
|
| 154 |
+
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
|
| 155 |
+
|
| 156 |
+
# Make pytorch loss function ignore all tokens up through the end of the response key
|
| 157 |
+
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
|
| 158 |
+
|
| 159 |
+
else:
|
| 160 |
+
for i in range(len(examples)):
|
| 161 |
+
response_token_ids_idxs = []
|
| 162 |
+
human_token_ids_idxs = []
|
| 163 |
+
|
| 164 |
+
for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
|
| 165 |
+
# find the indexes of the start of a response.
|
| 166 |
+
if (
|
| 167 |
+
self.response_token_ids
|
| 168 |
+
== batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
|
| 169 |
+
):
|
| 170 |
+
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
|
| 171 |
+
|
| 172 |
+
if len(response_token_ids_idxs) == 0:
|
| 173 |
+
warnings.warn(
|
| 174 |
+
f"Could not find response key `{self.response_template}` in the "
|
| 175 |
+
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
|
| 176 |
+
f"This instance will be ignored in loss calculation. "
|
| 177 |
+
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
| 178 |
+
)
|
| 179 |
+
batch["labels"][i, :] = self.ignore_index
|
| 180 |
+
|
| 181 |
+
human_token_ids = self.instruction_token_ids
|
| 182 |
+
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
|
| 183 |
+
# find the indexes of the start of a human answer.
|
| 184 |
+
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
|
| 185 |
+
human_token_ids_idxs.append(human_idx)
|
| 186 |
+
|
| 187 |
+
if len(human_token_ids_idxs) == 0:
|
| 188 |
+
warnings.warn(
|
| 189 |
+
f"Could not find instruction key `{self.instruction_template}` in the "
|
| 190 |
+
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
|
| 191 |
+
f"This instance will be ignored in loss calculation. "
|
| 192 |
+
f"Note, if this happens often, consider increasing the `max_seq_length`."
|
| 193 |
+
)
|
| 194 |
+
batch["labels"][i, :] = self.ignore_index
|
| 195 |
+
|
| 196 |
+
if (
|
| 197 |
+
len(human_token_ids_idxs) > 0
|
| 198 |
+
and len(response_token_ids_idxs) > 0
|
| 199 |
+
and human_token_ids_idxs[0] > response_token_ids_idxs[0]
|
| 200 |
+
):
|
| 201 |
+
human_token_ids_idxs = [0] + human_token_ids_idxs
|
| 202 |
+
|
| 203 |
+
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
|
| 204 |
+
# Make pytorch loss function ignore all non response tokens
|
| 205 |
+
if idx != 0:
|
| 206 |
+
batch["labels"][i, start:end] = self.ignore_index
|
| 207 |
+
else:
|
| 208 |
+
batch["labels"][i, :end] = self.ignore_index
|
| 209 |
+
|
| 210 |
+
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
|
| 211 |
+
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
|
| 212 |
+
|
| 213 |
+
return batch
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@dataclass
|
| 217 |
+
class RewardDataCollatorWithPadding:
|
| 218 |
+
r"""
|
| 219 |
+
Reward DataCollator class that pads the inputs to the maximum length of the batch.
|
| 220 |
+
Args:
|
| 221 |
+
tokenizer (`PreTrainedTokenizerBase`):
|
| 222 |
+
The tokenizer used for encoding the data.
|
| 223 |
+
padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
|
| 224 |
+
padding_strategy to pass to the tokenizer.
|
| 225 |
+
max_length (`Optional[int]`, `optional`, defaults to `None`):
|
| 226 |
+
The maximum length of the sequence to be processed.
|
| 227 |
+
pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`):
|
| 228 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 229 |
+
return_tensors (`str`, `optional`, defaults to `"pt"`):
|
| 230 |
+
The tensor type to use.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
tokenizer: PreTrainedTokenizerBase
|
| 234 |
+
padding: Union[bool, str] = True
|
| 235 |
+
max_length: Optional[int] = None
|
| 236 |
+
pad_to_multiple_of: Optional[int] = None
|
| 237 |
+
return_tensors: str = "pt"
|
| 238 |
+
|
| 239 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 240 |
+
features_chosen = []
|
| 241 |
+
features_rejected = []
|
| 242 |
+
margin = []
|
| 243 |
+
# check if we have a margin. If we do, we need to batch it as well
|
| 244 |
+
has_margin = "margin" in features[0]
|
| 245 |
+
for feature in features:
|
| 246 |
+
# check if the keys are named as expected
|
| 247 |
+
if (
|
| 248 |
+
"input_ids_chosen" not in feature
|
| 249 |
+
or "input_ids_rejected" not in feature
|
| 250 |
+
or "attention_mask_chosen" not in feature
|
| 251 |
+
or "attention_mask_rejected" not in feature
|
| 252 |
+
):
|
| 253 |
+
raise ValueError(
|
| 254 |
+
"The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
features_chosen.append(
|
| 258 |
+
{
|
| 259 |
+
"input_ids": feature["input_ids_chosen"],
|
| 260 |
+
"attention_mask": feature["attention_mask_chosen"],
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
features_rejected.append(
|
| 264 |
+
{
|
| 265 |
+
"input_ids": feature["input_ids_rejected"],
|
| 266 |
+
"attention_mask": feature["attention_mask_rejected"],
|
| 267 |
+
}
|
| 268 |
+
)
|
| 269 |
+
if has_margin:
|
| 270 |
+
margin.append(feature["margin"])
|
| 271 |
+
batch_chosen = self.tokenizer.pad(
|
| 272 |
+
features_chosen,
|
| 273 |
+
padding=self.padding,
|
| 274 |
+
max_length=self.max_length,
|
| 275 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 276 |
+
return_tensors=self.return_tensors,
|
| 277 |
+
)
|
| 278 |
+
batch_rejected = self.tokenizer.pad(
|
| 279 |
+
features_rejected,
|
| 280 |
+
padding=self.padding,
|
| 281 |
+
max_length=self.max_length,
|
| 282 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 283 |
+
return_tensors=self.return_tensors,
|
| 284 |
+
)
|
| 285 |
+
batch = {
|
| 286 |
+
"input_ids_chosen": batch_chosen["input_ids"],
|
| 287 |
+
"attention_mask_chosen": batch_chosen["attention_mask"],
|
| 288 |
+
"input_ids_rejected": batch_rejected["input_ids"],
|
| 289 |
+
"attention_mask_rejected": batch_rejected["attention_mask"],
|
| 290 |
+
"return_loss": True,
|
| 291 |
+
}
|
| 292 |
+
if has_margin:
|
| 293 |
+
margin = torch.tensor(margin, dtype=torch.float)
|
| 294 |
+
batch["margin"] = margin
|
| 295 |
+
return batch
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
@dataclass
|
| 299 |
+
class DPODataCollatorWithPadding:
|
| 300 |
+
r"""
|
| 301 |
+
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
|
| 302 |
+
Args:
|
| 303 |
+
pad_token_id (`int` defaults to 0):
|
| 304 |
+
The tokenizer's pad_token_id.
|
| 305 |
+
label_pad_token_id (`int`, defaults to -100):
|
| 306 |
+
The label used for masking.
|
| 307 |
+
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
|
| 308 |
+
Whether or not you model has an encoder_decoder architecture.
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
pad_token_id: int = 0
|
| 312 |
+
label_pad_token_id: int = -100
|
| 313 |
+
is_encoder_decoder: Optional[bool] = False
|
| 314 |
+
|
| 315 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 316 |
+
# first, pad everything to the same length
|
| 317 |
+
padded_batch = {}
|
| 318 |
+
for k in features[0].keys():
|
| 319 |
+
if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
|
| 320 |
+
if self.is_encoder_decoder:
|
| 321 |
+
to_pad = [torch.LongTensor(ex[k]) for ex in features]
|
| 322 |
+
|
| 323 |
+
if (k.startswith("prompt")) and (k.endswith("input_ids")):
|
| 324 |
+
if self.pad_token_id is None:
|
| 325 |
+
raise ValueError(
|
| 326 |
+
"Padding is enabled, but the tokenizer is not configured with a padding token."
|
| 327 |
+
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
|
| 328 |
+
" before calling the trainer."
|
| 329 |
+
)
|
| 330 |
+
padding_value = self.pad_token_id
|
| 331 |
+
elif k.endswith("_attention_mask"):
|
| 332 |
+
padding_value = 0
|
| 333 |
+
elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
|
| 334 |
+
padding_value = self.label_pad_token_id
|
| 335 |
+
else:
|
| 336 |
+
raise ValueError(f"Unexpected key in batch '{k}'")
|
| 337 |
+
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
|
| 338 |
+
else:
|
| 339 |
+
# adapted from https://stackoverflow.com/questions/73256206
|
| 340 |
+
if "prompt" in k:
|
| 341 |
+
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features]
|
| 342 |
+
else:
|
| 343 |
+
to_pad = [torch.LongTensor(ex[k]) for ex in features]
|
| 344 |
+
if k.endswith("_input_ids"):
|
| 345 |
+
if self.pad_token_id is None:
|
| 346 |
+
raise ValueError(
|
| 347 |
+
"Padding is enabled, but the tokenizer is not configured with a padding token."
|
| 348 |
+
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
|
| 349 |
+
" before calling the trainer."
|
| 350 |
+
)
|
| 351 |
+
padding_value = self.pad_token_id
|
| 352 |
+
elif k.endswith("_labels"):
|
| 353 |
+
padding_value = self.label_pad_token_id
|
| 354 |
+
elif k.endswith("_attention_mask"):
|
| 355 |
+
padding_value = 0
|
| 356 |
+
else:
|
| 357 |
+
raise ValueError(f"Unexpected key in batch '{k}'")
|
| 358 |
+
|
| 359 |
+
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
|
| 360 |
+
# for the prompt, flip back so padding is on left side
|
| 361 |
+
if "prompt" in k:
|
| 362 |
+
padded_batch[k] = padded_batch[k].flip(dims=[1])
|
| 363 |
+
elif k.endswith("_logps"):
|
| 364 |
+
# the cached reference model logprobs
|
| 365 |
+
padded_batch[k] = torch.tensor([ex[k] for ex in features])
|
| 366 |
+
else:
|
| 367 |
+
padded_batch[k] = [ex[k] for ex in features]
|
| 368 |
+
|
| 369 |
+
return padded_batch
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class ConstantLengthDataset(IterableDataset):
|
| 373 |
+
"""
|
| 374 |
+
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
| 375 |
+
The dataset also formats the text before tokenization with a specific format that is provided
|
| 376 |
+
by the user.
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
tokenizer (`transformers.PreTrainedTokenizer`):
|
| 380 |
+
The processor used for processing the data.
|
| 381 |
+
dataset (`dataset.Dataset`):
|
| 382 |
+
Dataset with text files.
|
| 383 |
+
dataset_text_field (`str`, **optional**):
|
| 384 |
+
Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`.
|
| 385 |
+
formatting_func (`Callable`, **optional**):
|
| 386 |
+
Function that formats the text before tokenization. Usually it is recommended to have follows a certain
|
| 387 |
+
pattern such as `"### Question: {question} ### Answer: {answer}"`
|
| 388 |
+
infinite (`bool`, *optional*, defaults to `False`):
|
| 389 |
+
If True the iterator is reset after dataset reaches end else stops.
|
| 390 |
+
seq_length (`int`, *optional*, defaults to `1024`):
|
| 391 |
+
Length of token sequences to return.
|
| 392 |
+
num_of_sequences (`int`, *optional*, defaults to `1024`):
|
| 393 |
+
Number of token sequences to keep in buffer.
|
| 394 |
+
chars_per_token (`int`, *optional*, defaults to `3.6`):
|
| 395 |
+
Number of characters per token used to estimate number of tokens in text buffer.
|
| 396 |
+
eos_token_id (`int`, *optional*, defaults to `0`):
|
| 397 |
+
Id of the end of sequence token if the passed tokenizer does not have an EOS token.
|
| 398 |
+
shuffle ('bool', *optional*, defaults to True)
|
| 399 |
+
Shuffle the examples before they are returned
|
| 400 |
+
append_concat_token ('bool', *optional*, defaults to True)
|
| 401 |
+
If true, appends `eos_token_id` at the end of each sample being packed.
|
| 402 |
+
add_special_tokens ('bool', *optional*, defaults to True)
|
| 403 |
+
If true, tokenizers adds special tokens to each sample being packed.
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
def __init__(
|
| 407 |
+
self,
|
| 408 |
+
tokenizer,
|
| 409 |
+
dataset,
|
| 410 |
+
dataset_text_field=None,
|
| 411 |
+
formatting_func=None,
|
| 412 |
+
infinite=False,
|
| 413 |
+
seq_length=1024,
|
| 414 |
+
num_of_sequences=1024,
|
| 415 |
+
chars_per_token=3.6,
|
| 416 |
+
eos_token_id=0,
|
| 417 |
+
shuffle=True,
|
| 418 |
+
append_concat_token=True,
|
| 419 |
+
add_special_tokens=True,
|
| 420 |
+
):
|
| 421 |
+
self.tokenizer = tokenizer
|
| 422 |
+
|
| 423 |
+
if tokenizer.eos_token_id is None:
|
| 424 |
+
warnings.warn(
|
| 425 |
+
"The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds"
|
| 426 |
+
f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id."
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
|
| 430 |
+
self.dataset = dataset
|
| 431 |
+
self.seq_length = seq_length
|
| 432 |
+
self.infinite = infinite
|
| 433 |
+
self.current_size = 0
|
| 434 |
+
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
|
| 435 |
+
self.shuffle = shuffle
|
| 436 |
+
self.append_concat_token = append_concat_token
|
| 437 |
+
self.add_special_tokens = add_special_tokens
|
| 438 |
+
if formatting_func is None:
|
| 439 |
+
self.formatting_func = lambda x: x[dataset_text_field]
|
| 440 |
+
else:
|
| 441 |
+
self.formatting_func = formatting_func
|
| 442 |
+
|
| 443 |
+
if formatting_func is not None:
|
| 444 |
+
if formatting_func.__code__.co_argcount > 1:
|
| 445 |
+
warnings.warn(
|
| 446 |
+
"The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
|
| 447 |
+
" which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
def __len__(self):
|
| 451 |
+
return len(self.dataset)
|
| 452 |
+
|
| 453 |
+
def __iter__(self):
|
| 454 |
+
iterator = iter(self.dataset)
|
| 455 |
+
more_examples = True
|
| 456 |
+
while more_examples:
|
| 457 |
+
buffer, buffer_len = [], 0
|
| 458 |
+
while True:
|
| 459 |
+
if buffer_len >= self.max_buffer_size:
|
| 460 |
+
break
|
| 461 |
+
try:
|
| 462 |
+
buffer.append(self.formatting_func(next(iterator)))
|
| 463 |
+
buffer_len += len(buffer[-1])
|
| 464 |
+
except StopIteration:
|
| 465 |
+
if self.infinite:
|
| 466 |
+
iterator = iter(self.dataset)
|
| 467 |
+
warnings.warn("The dataset reached end and the iterator is reset to the start.")
|
| 468 |
+
else:
|
| 469 |
+
more_examples = False
|
| 470 |
+
break
|
| 471 |
+
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
|
| 472 |
+
"input_ids"
|
| 473 |
+
]
|
| 474 |
+
all_token_ids = []
|
| 475 |
+
for tokenized_input in tokenized_inputs:
|
| 476 |
+
if self.append_concat_token:
|
| 477 |
+
tokenized_input = tokenized_input + [self.concat_token_id]
|
| 478 |
+
all_token_ids.extend(tokenized_input)
|
| 479 |
+
examples = []
|
| 480 |
+
for i in range(0, len(all_token_ids), self.seq_length):
|
| 481 |
+
input_ids = all_token_ids[i : i + self.seq_length]
|
| 482 |
+
if len(input_ids) == self.seq_length:
|
| 483 |
+
examples.append(input_ids)
|
| 484 |
+
if self.shuffle:
|
| 485 |
+
random.shuffle(examples)
|
| 486 |
+
for example in examples:
|
| 487 |
+
self.current_size += 1
|
| 488 |
+
yield {
|
| 489 |
+
"input_ids": torch.LongTensor(example),
|
| 490 |
+
"labels": torch.LongTensor(example),
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class RunningMoments:
|
| 495 |
+
def __init__(self, accelerator):
|
| 496 |
+
"""
|
| 497 |
+
Calculates the running mean and standard deviation of a data stream. Reference:
|
| 498 |
+
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
|
| 499 |
+
"""
|
| 500 |
+
self.mean = 0
|
| 501 |
+
self.std = 1
|
| 502 |
+
self.var = 1
|
| 503 |
+
self.count = 1e-24
|
| 504 |
+
self.accelerator = accelerator
|
| 505 |
+
|
| 506 |
+
@torch.no_grad()
|
| 507 |
+
def update(self, xs: torch.Tensor) -> Tuple[float, float]:
|
| 508 |
+
"""
|
| 509 |
+
Updates running moments from batch's moments computed across ranks
|
| 510 |
+
"""
|
| 511 |
+
if self.accelerator.use_distributed:
|
| 512 |
+
xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
|
| 513 |
+
else:
|
| 514 |
+
xs_count = xs.numel()
|
| 515 |
+
xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
|
| 516 |
+
xs_mean, xs_var = xs_mean.float(), xs_var.float()
|
| 517 |
+
|
| 518 |
+
delta = xs_mean - self.mean
|
| 519 |
+
tot_count = self.count + xs_count
|
| 520 |
+
|
| 521 |
+
new_sum = xs_var * xs_count
|
| 522 |
+
# correct old_sum deviation accounting for the new mean
|
| 523 |
+
old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
|
| 524 |
+
tot_sum = old_sum + new_sum
|
| 525 |
+
|
| 526 |
+
self.mean += delta * xs_count / tot_count
|
| 527 |
+
self.var = tot_sum / tot_count
|
| 528 |
+
self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
|
| 529 |
+
self.count = tot_count
|
| 530 |
+
|
| 531 |
+
return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@torch.no_grad()
|
| 535 |
+
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
|
| 536 |
+
"""
|
| 537 |
+
Computes element-wise mean and variance of the tensor across processes. Reference:
|
| 538 |
+
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
|
| 539 |
+
"""
|
| 540 |
+
xs = xs.to(accelerator.device)
|
| 541 |
+
sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
|
| 542 |
+
sum_and_count = accelerator.reduce(sum_and_count)
|
| 543 |
+
global_sum, count = sum_and_count
|
| 544 |
+
global_mean = global_sum / count
|
| 545 |
+
|
| 546 |
+
sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
|
| 547 |
+
sum_var = accelerator.reduce(sum_var)
|
| 548 |
+
global_var = sum_var / count
|
| 549 |
+
|
| 550 |
+
return global_mean.to(device), global_var.to(device), count.to(device)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def compute_accuracy(eval_pred) -> Dict[str, float]:
|
| 554 |
+
predictions, labels = eval_pred
|
| 555 |
+
# Here, predictions is rewards_chosen and rewards_rejected.
|
| 556 |
+
# We want to see how much of the time rewards_chosen > rewards_rejected.
|
| 557 |
+
if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
|
| 558 |
+
warnings.warn(
|
| 559 |
+
f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
|
| 560 |
+
)
|
| 561 |
+
predictions = np.argmax(predictions, axis=1)
|
| 562 |
+
|
| 563 |
+
accuracy = np.array(predictions == labels, dtype=float).mean().item()
|
| 564 |
+
return {"accuracy": accuracy}
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
|
| 568 |
+
if tensor.size(dim) >= length:
|
| 569 |
+
return tensor
|
| 570 |
+
else:
|
| 571 |
+
pad_size = list(tensor.shape)
|
| 572 |
+
pad_size[dim] = length - tensor.size(dim)
|
| 573 |
+
return torch.cat(
|
| 574 |
+
[
|
| 575 |
+
tensor,
|
| 576 |
+
pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
|
| 577 |
+
],
|
| 578 |
+
dim=dim,
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def disable_dropout_in_model(model: torch.nn.Module) -> None:
|
| 583 |
+
for module in model.modules():
|
| 584 |
+
if isinstance(module, torch.nn.Dropout):
|
| 585 |
+
module.p = 0
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def exact_div(a, b, a_str, b_str, custom_error_message=""):
|
| 589 |
+
q = a // b
|
| 590 |
+
if a != q * b:
|
| 591 |
+
raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
|
| 592 |
+
return q
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
|
| 596 |
+
class PerPromptStatTracker:
|
| 597 |
+
r"""
|
| 598 |
+
Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm
|
| 599 |
+
|
| 600 |
+
Args:
|
| 601 |
+
buffer_size (`int`):
|
| 602 |
+
Size of the buffer to keep for each prompt.
|
| 603 |
+
min_count (`int`):
|
| 604 |
+
Minimum number of samples to keep in the buffer before calculating the mean and std.
|
| 605 |
+
"""
|
| 606 |
+
|
| 607 |
+
def __init__(self, buffer_size, min_count):
|
| 608 |
+
self.buffer_size = buffer_size
|
| 609 |
+
self.min_count = min_count
|
| 610 |
+
self.stats = {}
|
| 611 |
+
|
| 612 |
+
def update(self, prompts, rewards):
|
| 613 |
+
prompts = np.array(prompts)
|
| 614 |
+
rewards = np.array(rewards)
|
| 615 |
+
unique = np.unique(prompts)
|
| 616 |
+
advantages = np.empty_like(rewards)
|
| 617 |
+
for prompt in unique:
|
| 618 |
+
prompt_rewards = rewards[prompts == prompt]
|
| 619 |
+
if prompt not in self.stats:
|
| 620 |
+
self.stats[prompt] = deque(maxlen=self.buffer_size)
|
| 621 |
+
self.stats[prompt].extend(prompt_rewards)
|
| 622 |
+
|
| 623 |
+
if len(self.stats[prompt]) < self.min_count:
|
| 624 |
+
mean = np.mean(rewards)
|
| 625 |
+
std = np.std(rewards) + 1e-6
|
| 626 |
+
else:
|
| 627 |
+
mean = np.mean(self.stats[prompt])
|
| 628 |
+
std = np.std(self.stats[prompt]) + 1e-6
|
| 629 |
+
advantages[prompts == prompt] = (prompt_rewards - mean) / std
|
| 630 |
+
|
| 631 |
+
return advantages
|
| 632 |
+
|
| 633 |
+
def get_stats(self):
|
| 634 |
+
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def neftune_post_forward_hook(module, input, output):
|
| 638 |
+
"""
|
| 639 |
+
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
|
| 640 |
+
torch.nn.Embedding layers. This method is slightly adapted from the original source code
|
| 641 |
+
that can be found here: https://github.com/neelsjain/NEFTune
|
| 642 |
+
|
| 643 |
+
Simply add it to your model as follows:
|
| 644 |
+
```python
|
| 645 |
+
model = ...
|
| 646 |
+
model.embed_tokens.neftune_noise_alpha = 0.1
|
| 647 |
+
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
|
| 648 |
+
```
|
| 649 |
+
|
| 650 |
+
Args:
|
| 651 |
+
module (`torch.nn.Module`):
|
| 652 |
+
The embedding module where the hook is attached. Note that you need to set
|
| 653 |
+
`module.neftune_noise_alpha` to the desired noise alpha value.
|
| 654 |
+
input (`torch.Tensor`):
|
| 655 |
+
The input tensor to the model.
|
| 656 |
+
output (`torch.Tensor`):
|
| 657 |
+
The output tensor of the model (i.e. the embeddings).
|
| 658 |
+
"""
|
| 659 |
+
if module.training:
|
| 660 |
+
dims = torch.tensor(output.size(1) * output.size(2))
|
| 661 |
+
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
|
| 662 |
+
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
| 663 |
+
return output
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def peft_module_casting_to_bf16(model):
|
| 667 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
| 668 |
+
|
| 669 |
+
for name, module in model.named_modules():
|
| 670 |
+
if isinstance(module, BaseTunerLayer):
|
| 671 |
+
module = module.to(torch.bfloat16)
|
| 672 |
+
elif isinstance(module, torch.nn.LayerNorm) or "norm" in name:
|
| 673 |
+
module = module.to(torch.float32)
|
| 674 |
+
elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
|
| 675 |
+
if hasattr(module, "weight"):
|
| 676 |
+
if module.weight.dtype == torch.float32:
|
| 677 |
+
module = module.to(torch.bfloat16)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
|
| 681 |
+
if is_unsloth_available():
|
| 682 |
+
# Unsloth adds a new attribute in the model config `unsloth_version`
|
| 683 |
+
# to keep track of models that have been patched with unsloth.
|
| 684 |
+
if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None:
|
| 685 |
+
tag_names.append("unsloth")
|
| 686 |
+
|
| 687 |
+
if kwargs is not None:
|
| 688 |
+
if "tags" not in kwargs:
|
| 689 |
+
kwargs["tags"] = tag_names
|
| 690 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
| 691 |
+
kwargs["tags"].extend(tag_names)
|
| 692 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
| 693 |
+
tag_names.append(kwargs["tags"])
|
| 694 |
+
kwargs["tags"] = tag_names
|
| 695 |
+
return kwargs
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]:
|
| 699 |
+
if model_config.load_in_4bit:
|
| 700 |
+
quantization_config = BitsAndBytesConfig(
|
| 701 |
+
load_in_4bit=True,
|
| 702 |
+
bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
|
| 703 |
+
bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
|
| 704 |
+
bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
|
| 705 |
+
)
|
| 706 |
+
elif model_config.load_in_8bit:
|
| 707 |
+
quantization_config = BitsAndBytesConfig(
|
| 708 |
+
load_in_8bit=True,
|
| 709 |
+
)
|
| 710 |
+
else:
|
| 711 |
+
quantization_config = None
|
| 712 |
+
|
| 713 |
+
return quantization_config
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def get_kbit_device_map() -> Optional[Dict[str, int]]:
|
| 717 |
+
if is_xpu_available():
|
| 718 |
+
return {"": f"xpu:{PartialState().local_process_index}"}
|
| 719 |
+
elif torch.cuda.is_available():
|
| 720 |
+
return {"": PartialState().local_process_index}
|
| 721 |
+
else:
|
| 722 |
+
return None
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
|
| 726 |
+
if model_config.use_peft is False:
|
| 727 |
+
return None
|
| 728 |
+
|
| 729 |
+
peft_config = LoraConfig(
|
| 730 |
+
r=model_config.lora_r,
|
| 731 |
+
lora_alpha=model_config.lora_alpha,
|
| 732 |
+
lora_dropout=model_config.lora_dropout,
|
| 733 |
+
bias="none",
|
| 734 |
+
task_type="CAUSAL_LM",
|
| 735 |
+
target_modules=model_config.lora_target_modules,
|
| 736 |
+
modules_to_save=model_config.lora_modules_to_save,
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
return peft_config
|
plot.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import os
|
| 5 |
+
import glob2
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def find_pareto_points(obtained_scores, threshold=0.02):
|
| 10 |
+
n = len(obtained_scores)
|
| 11 |
+
if n == 1:
|
| 12 |
+
return obtained_scores
|
| 13 |
+
pareto_index = []
|
| 14 |
+
high_low = np.max(obtained_scores, axis=0) - np.min(obtained_scores, axis=0)
|
| 15 |
+
for i in range(n):
|
| 16 |
+
if not any(np.all((obtained_scores - obtained_scores[i] - threshold * high_low) > 0.0, axis=1)):
|
| 17 |
+
pareto_index.append(i)
|
| 18 |
+
|
| 19 |
+
points = obtained_scores[np.array(pareto_index)]
|
| 20 |
+
arg_index = np.argsort(points[:, 0])
|
| 21 |
+
points = points[arg_index]
|
| 22 |
+
print(points)
|
| 23 |
+
sorted_index = [0]
|
| 24 |
+
remaining_index = np.ones(len(points))
|
| 25 |
+
i = 0
|
| 26 |
+
remaining_index[i] = 0
|
| 27 |
+
while sum(remaining_index):
|
| 28 |
+
distance = ((points[np.where(remaining_index)] - points[i]) ** 2 ).sum(axis=1)
|
| 29 |
+
min_index = np.where(remaining_index > 0)[0][np.argmin(distance)]
|
| 30 |
+
sorted_index.append(min_index)
|
| 31 |
+
i = min_index
|
| 32 |
+
remaining_index[i] = 0
|
| 33 |
+
return points[np.array(sorted_index)]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
index = 1
|
| 38 |
+
colors = sns.color_palette('Paired')
|
| 39 |
+
def plot_points(dir, label, style='-*', color='b', shift=[0,0], txt_color='black', normalize_path=None, reverse=True):
|
| 40 |
+
threshold = 0.01
|
| 41 |
+
desired_scores = []
|
| 42 |
+
obtained_scores = []
|
| 43 |
+
|
| 44 |
+
paths = [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*.csv'))]
|
| 45 |
+
paths += [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*', '*.csv'))]
|
| 46 |
+
|
| 47 |
+
pref_lis = []
|
| 48 |
+
for path in paths:
|
| 49 |
+
if '.csv' in path:
|
| 50 |
+
full_path = path
|
| 51 |
+
data = pd.read_csv(full_path)
|
| 52 |
+
# morlhf has less points, let the threshold larger to make the frontier better
|
| 53 |
+
if 'ppo' in path and len(paths) <= 5:
|
| 54 |
+
threshold = 0.5
|
| 55 |
+
obtained_scores.append([np.mean(data['obtained_score1']), np.mean(data['obtained_score2'])])
|
| 56 |
+
if 'pref' in path:
|
| 57 |
+
# get the preference
|
| 58 |
+
if 'eval_data_pref' in path:
|
| 59 |
+
pref = path.split('eval_data_pref')[-1].strip().split('_')[0]
|
| 60 |
+
pref_lis.append(float(pref))
|
| 61 |
+
|
| 62 |
+
print(pref_lis)
|
| 63 |
+
desired_scores = np.array(desired_scores)
|
| 64 |
+
obtained_scores = np.array(obtained_scores)
|
| 65 |
+
|
| 66 |
+
if normalize_path is not None:
|
| 67 |
+
norm_info = np.load(normalize_path)
|
| 68 |
+
norm_info = np.array(norm_info).reshape(2, 2)
|
| 69 |
+
for i in range(2):
|
| 70 |
+
obtained_scores[:, i] = (obtained_scores[:, i] - norm_info[i][0]) / norm_info[i][1]
|
| 71 |
+
|
| 72 |
+
global index
|
| 73 |
+
markersize = 10 if ('*' in style or 'o' in style) else 9
|
| 74 |
+
pareto_points = find_pareto_points(obtained_scores, threshold)
|
| 75 |
+
plt.scatter(obtained_scores[:, 0], obtained_scores[:, 1], marker=style[-1], color=colors[index], s=markersize + 60)
|
| 76 |
+
if len(pref_lis):
|
| 77 |
+
for i in range(len(obtained_scores)):
|
| 78 |
+
plt.annotate('{}'.format(round(pref_lis[i], 1)), (obtained_scores[i, 0] + shift[0], obtained_scores[i, 1] + shift[1]), size=4, color=txt_color)
|
| 79 |
+
|
| 80 |
+
plt.plot(pareto_points[:, 0], pareto_points[:, 1], style, c=colors[index], markersize=markersize, label=label)
|
| 81 |
+
index += 2
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
plt.figure(figsize=(5, 4))
|
| 85 |
+
|
| 86 |
+
name1 = 'harmless'
|
| 87 |
+
name2 = 'helpful'
|
| 88 |
+
|
| 89 |
+
### replace the paths to your own paths
|
| 90 |
+
plot_points('./logs_trl/eval_pretrained', 'Llama 2 base', '*')
|
| 91 |
+
plot_points('./logs_trl/eval_sft_alldata', 'SFT', '*')
|
| 92 |
+
plot_points('./eval_ppo_pref/', 'MORLHF', '--D', shift=[-0.012, -0.022])
|
| 93 |
+
plot_points('./logs_ppo/eval_pposoups_llamma2_klreg0.2', 'Rewarded Soups', style='--s', shift=[-0.012, -0.022])
|
| 94 |
+
plot_points('.logs_trl/evalnew_onlinefix_helpful_harmlesshelpful_iter2', 'RiC', style='-o', shift=[-0.012, -0.022], txt_color='white')
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
plt.xlabel('$R_1$ ({})'.format(name1), fontsize=12)
|
| 98 |
+
plt.ylabel('$R_2$ ({})'.format(name2), fontsize=12)
|
| 99 |
+
plt.legend(fontsize=11, loc='lower left')
|
| 100 |
+
plt.tight_layout()
|
| 101 |
+
plt.savefig('ric_assistant_{}_{}.pdf'.format(name1, name2))
|
| 102 |
+
|
| 103 |
+
|
ppo/__pycache__/multi_reward_models.cpython-310.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
ppo/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
ppo/eval_ppo_single_model.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from accelerate import Accelerator
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import AutoModelForCausalLM, HfArgumentParser
|
| 8 |
+
from transformers import DataCollatorWithPadding
|
| 9 |
+
from trl import set_seed
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from utils import load_main_tokenizer
|
| 14 |
+
from multi_reward_models import RewardModels
|
| 15 |
+
from utils import Instructions, Instructions_summary, build_dataset_eval, build_dataset_summary_eval, check_lora_in_model_path, merge_lora_weight
|
| 16 |
+
tqdm.pandas()
|
| 17 |
+
|
| 18 |
+
# define paths for two datasets
|
| 19 |
+
hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
|
| 20 |
+
summary_dataset_path = 'openai/summarize_from_feedback'
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class ScriptArguments:
|
| 24 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 25 |
+
save_directory: Optional[str] = field(default='./logs_ppo_assistant')
|
| 26 |
+
batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
|
| 27 |
+
wandb_name: Optional[str] = field(default='eval_morlhf_kl0.2_multirew0.2_gen128_harmless_helpful', metadata={"help": "Name for this experiment"})
|
| 28 |
+
reward_names:Optional[str] = field(default='harmless,helpful')
|
| 29 |
+
base_model_name: Optional[str] = field(default='./logs_ppo_prefnew/ppo_llamma2_klreg0.2_harmless_helpful_preference0.5_0.5/batch_207')
|
| 30 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
|
| 31 |
+
|
| 32 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 33 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 34 |
+
exp_type = script_args.exp_type
|
| 35 |
+
base_model_name = script_args.base_model_name
|
| 36 |
+
tokenier_name = script_args.base_model_name
|
| 37 |
+
print('base model: ', base_model_name)
|
| 38 |
+
|
| 39 |
+
reward_names = [x.strip() for x in script_args.reward_names.split(',')]
|
| 40 |
+
num_rewards = len(reward_names)
|
| 41 |
+
print('number of rewards: {}'.format(num_rewards))
|
| 42 |
+
reward_path_tokenizer_dict = {
|
| 43 |
+
'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
|
| 44 |
+
'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
|
| 45 |
+
'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
|
| 46 |
+
'summary': ['Tristan/gpt2_reward_summarization'],
|
| 47 |
+
'faithful':['CogComp/bart-faithful-summary-detector'],
|
| 48 |
+
'humor': ['mohameddhiab/humor-no-humor'],
|
| 49 |
+
}
|
| 50 |
+
reward_model_path_list = []
|
| 51 |
+
rm_tokenizer_path_list = []
|
| 52 |
+
for name in reward_names:
|
| 53 |
+
if name not in reward_path_tokenizer_dict.keys():
|
| 54 |
+
raise NotImplementedError
|
| 55 |
+
reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 56 |
+
rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 57 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
process_id = Accelerator().local_process_index
|
| 61 |
+
gpu_id = process_id
|
| 62 |
+
print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
##############
|
| 66 |
+
rm_gpu_id = gpu_id
|
| 67 |
+
reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, rm_gpu_id)
|
| 68 |
+
|
| 69 |
+
set_seed(8888)
|
| 70 |
+
current_device = Accelerator().local_process_index
|
| 71 |
+
print(current_device)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 75 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 76 |
+
base_model_name,
|
| 77 |
+
torch_dtype=torch.bfloat16, # fast inference
|
| 78 |
+
device_map=gpu_id,
|
| 79 |
+
)
|
| 80 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 81 |
+
model = merge_lora_weight(model, base_model_name)
|
| 82 |
+
|
| 83 |
+
generation_kwargs = {
|
| 84 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 85 |
+
"min_length": -1,
|
| 86 |
+
"top_k": 0.0,
|
| 87 |
+
"top_p": 0.9,
|
| 88 |
+
"do_sample": True,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_clean_data(full_responses, full_prompts):
|
| 93 |
+
full_responses_clean = []
|
| 94 |
+
for i, response in enumerate(full_responses):
|
| 95 |
+
full_prompts[i] = full_prompts[i].strip('[PAD] ')
|
| 96 |
+
response = response.strip('[PAD] ')
|
| 97 |
+
temp_resp = response.replace(full_prompts[i], '').strip('<s>').strip('</s>')
|
| 98 |
+
if '</s>' in temp_resp:
|
| 99 |
+
temp_resp = temp_resp[:temp_resp.rindex('</s>')]
|
| 100 |
+
temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 101 |
+
temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 102 |
+
temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 103 |
+
temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 104 |
+
temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 105 |
+
full_responses_clean.append(full_prompts[i] + ' ' + temp_resp)
|
| 106 |
+
return full_responses_clean
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
### for evaluation
|
| 110 |
+
print('evaluation........')
|
| 111 |
+
tokenizer.padding_side = "left"
|
| 112 |
+
if exp_type == 'assistant':
|
| 113 |
+
valid_dataset = build_dataset_eval(hhrlhf_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
|
| 114 |
+
instructions = Instructions()
|
| 115 |
+
else:
|
| 116 |
+
valid_dataset = build_dataset_summary_eval(summary_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
|
| 117 |
+
instructions = Instructions_summary()
|
| 118 |
+
print(f"Size of the validation set: {len(valid_dataset)}")
|
| 119 |
+
|
| 120 |
+
valid_batch_size = 1
|
| 121 |
+
valid_dataset = valid_dataset.remove_columns(['prompt', 'query'])
|
| 122 |
+
for key in ['key', 'text', 'response']:
|
| 123 |
+
if key in valid_dataset.column_names:
|
| 124 |
+
valid_dataset = valid_dataset.remove_columns(key)
|
| 125 |
+
|
| 126 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 127 |
+
valid_data_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, drop_last=True, collate_fn=data_collator)
|
| 128 |
+
accelerator = Accelerator()
|
| 129 |
+
model, valid_data_loader = accelerator.prepare(model, valid_data_loader)
|
| 130 |
+
|
| 131 |
+
full_response_tensors = []
|
| 132 |
+
full_prompts = []
|
| 133 |
+
|
| 134 |
+
pbar = tqdm(total=len(valid_dataset) // valid_batch_size // accelerator.num_processes)
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
for i, batch in enumerate(valid_data_loader):
|
| 137 |
+
response_tensors = accelerator.unwrap_model(model).generate(batch['input_ids'], attention_mask=batch['attention_mask'], **generation_kwargs)
|
| 138 |
+
full_response_tensors.extend(response_tensors)
|
| 139 |
+
full_prompts.extend(batch['input_ids'])
|
| 140 |
+
pbar.update(1)
|
| 141 |
+
|
| 142 |
+
full_prompts = tokenizer.batch_decode(full_prompts)
|
| 143 |
+
full_responses = tokenizer.batch_decode(full_response_tensors)
|
| 144 |
+
full_responses = get_clean_data(full_responses, full_prompts)
|
| 145 |
+
# Compute score
|
| 146 |
+
queries_responses = [
|
| 147 |
+
(instructions.get_input(text), instructions.get_response(text))
|
| 148 |
+
for text in full_responses
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
if hasattr(instructions, 'get_post'):
|
| 153 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
|
| 154 |
+
else:
|
| 155 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses)
|
| 156 |
+
|
| 157 |
+
all_rewards = []
|
| 158 |
+
for i in range(reward_models.num_rewards):
|
| 159 |
+
all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
|
| 160 |
+
### merge data
|
| 161 |
+
### error here may because of old version of transformers/accelerate/peft
|
| 162 |
+
all_full_prompts = accelerator.gather_for_metrics(full_prompts)
|
| 163 |
+
all_full_responses = accelerator.gather_for_metrics(full_responses)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if process_id == 0:
|
| 167 |
+
evaluation_result = {
|
| 168 |
+
'prompt':all_full_prompts,
|
| 169 |
+
'response': all_full_responses,
|
| 170 |
+
}
|
| 171 |
+
for i in range(reward_models.num_rewards):
|
| 172 |
+
evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
|
| 173 |
+
print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
|
| 174 |
+
|
| 175 |
+
dataframe = pd.DataFrame(evaluation_result)
|
| 176 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data.csv'))
|
| 177 |
+
|
ppo/eval_rewarded_soups.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from accelerate import Accelerator
|
| 6 |
+
import torch
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from transformers import HfArgumentParser
|
| 10 |
+
from transformers import AutoModelForCausalLM, DataCollatorWithPadding
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from trl import set_seed
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from utils import get_clean_data, load_main_tokenizer, save_configs, print_trainable_parameters, \
|
| 16 |
+
merge_weights_with_preference, Instructions, Instructions_summary, build_dataset_eval, build_dataset_summary_eval
|
| 17 |
+
from multi_reward_models import RewardModels
|
| 18 |
+
tqdm.pandas()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# define paths for two datasets
|
| 22 |
+
hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
|
| 23 |
+
summary_dataset_path = 'openai/summarize_from_feedback'
|
| 24 |
+
tokenizer_path = 'meta-llama/Llama-2-7b-hf'
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class ScriptArguments:
|
| 28 |
+
save_directory: Optional[str] = field(default='./logs_rewardedsoups_summary_eval')
|
| 29 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "minibatch size for eval"})
|
| 30 |
+
wandb_name: Optional[str] = field(default='eval_pposoups_klreg0.2_harmless_helpful', metadata={"help": "Name for this experiment"})
|
| 31 |
+
reward_names:Optional[str] = field(default='harmless,helpful')
|
| 32 |
+
base_model_path1: Optional[str]=field(default='./ppo_llamma2_klreg0.2_harmless/batch_200')
|
| 33 |
+
base_model_path2: Optional[str]=field(default='./ppo_llamma2_klreg0.2_helpful/batch_200')
|
| 34 |
+
base_model_path3: Optional[str]=field(default='')
|
| 35 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
|
| 36 |
+
|
| 37 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 38 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 39 |
+
exp_type = script_args.exp_type
|
| 40 |
+
base_model_name_1 = script_args.base_model_path1
|
| 41 |
+
base_model_name_2 = script_args.base_model_path2
|
| 42 |
+
base_model_name_3 = script_args.base_model_path3
|
| 43 |
+
tokenier_name = tokenizer_path
|
| 44 |
+
|
| 45 |
+
reward_names = [x.strip() for x in script_args.reward_names.split(',')]
|
| 46 |
+
reward_path_tokenizer_dict = {
|
| 47 |
+
'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
|
| 48 |
+
'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
|
| 49 |
+
'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
|
| 50 |
+
'summary': ['Tristan/gpt2_reward_summarization'],
|
| 51 |
+
'faithful':['CogComp/bart-faithful-summary-detector'],
|
| 52 |
+
'humor': ['mohameddhiab/humor-no-humor'],
|
| 53 |
+
}
|
| 54 |
+
reward_model_path_list = []
|
| 55 |
+
rm_tokenizer_path_list = []
|
| 56 |
+
for name in reward_names:
|
| 57 |
+
if name not in reward_path_tokenizer_dict.keys():
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 60 |
+
rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 61 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 62 |
+
save_info = {
|
| 63 |
+
'base_model_name_1': base_model_name_1,
|
| 64 |
+
'base_model_name_2': base_model_name_2,
|
| 65 |
+
'base_model_name_3': base_model_name_3,
|
| 66 |
+
'reward_peft_path1': reward_model_path_list[0],
|
| 67 |
+
'reward_peft_path2': reward_model_path_list[1],
|
| 68 |
+
'tokenier_name': tokenier_name
|
| 69 |
+
}
|
| 70 |
+
for i in range(len(reward_model_path_list)):
|
| 71 |
+
save_info['reward_peft_path{}'.format(i+1)] = reward_model_path_list[i]
|
| 72 |
+
save_configs(save_info, os.path.join(script_args.save_directory, script_args.wandb_name))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
accelerator = Accelerator()
|
| 76 |
+
process_id = Accelerator().local_process_index
|
| 77 |
+
gpu_id = process_id
|
| 78 |
+
print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
|
| 79 |
+
reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
set_seed(8888)
|
| 83 |
+
current_device = Accelerator().local_process_index
|
| 84 |
+
print(current_device)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 88 |
+
if exp_type == 'assistant':
|
| 89 |
+
valid_dataset = build_dataset_eval(hhrlhf_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
|
| 90 |
+
instructions = Instructions()
|
| 91 |
+
else:
|
| 92 |
+
valid_dataset = build_dataset_summary_eval(summary_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
|
| 93 |
+
instructions = Instructions_summary()
|
| 94 |
+
print(f"Size of the validation set: {len(valid_dataset)}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
valid_dataset = valid_dataset.remove_columns(['prompt', 'query'])
|
| 98 |
+
for key in ['key', 'text', 'response']:
|
| 99 |
+
if key in valid_dataset.column_names:
|
| 100 |
+
valid_dataset = valid_dataset.remove_columns(key)
|
| 101 |
+
|
| 102 |
+
def evaluate_model(temp_save_path, tokenizer, valid_dataset):
|
| 103 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 104 |
+
valid_data_loader = DataLoader(valid_dataset, batch_size=script_args.mini_batch_size, drop_last=True, collate_fn=data_collator)
|
| 105 |
+
### load_merged_model
|
| 106 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 107 |
+
temp_save_path,
|
| 108 |
+
torch_dtype=torch.bfloat16, # for merging model
|
| 109 |
+
device_map=gpu_id
|
| 110 |
+
)
|
| 111 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 112 |
+
|
| 113 |
+
accelerator = Accelerator()
|
| 114 |
+
model, valid_data_loader = accelerator.prepare(model, valid_data_loader)
|
| 115 |
+
|
| 116 |
+
generation_kwargs = {
|
| 117 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 118 |
+
"min_length": -1,
|
| 119 |
+
"top_k": 0.0,
|
| 120 |
+
"top_p": 0.9,
|
| 121 |
+
"do_sample": True,
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
full_responses = []
|
| 125 |
+
full_prompts = []
|
| 126 |
+
pbar = tqdm(total=len(valid_dataset) // script_args.mini_batch_size // accelerator.num_processes)
|
| 127 |
+
with torch.no_grad():
|
| 128 |
+
for i, batch in enumerate(valid_data_loader):
|
| 129 |
+
response_tensors = accelerator.unwrap_model(model).generate(batch["input_ids"], **generation_kwargs) #length_sampler=output_length_sampler,
|
| 130 |
+
full_responses.extend(response_tensors)
|
| 131 |
+
full_prompts.extend(batch['input_ids'])
|
| 132 |
+
pbar.update(1)
|
| 133 |
+
|
| 134 |
+
full_responses = tokenizer.batch_decode(full_responses)
|
| 135 |
+
full_prompts = tokenizer.batch_decode(full_prompts)
|
| 136 |
+
# clean data
|
| 137 |
+
full_prompts, full_responses = get_clean_data(full_responses, full_prompts)
|
| 138 |
+
|
| 139 |
+
queries_responses = [
|
| 140 |
+
(instructions.get_input(text), instructions.get_response(text))
|
| 141 |
+
for text in full_responses
|
| 142 |
+
]
|
| 143 |
+
if hasattr(instructions, 'get_post'):
|
| 144 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
|
| 145 |
+
else:
|
| 146 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses)
|
| 147 |
+
|
| 148 |
+
### error here may because of old version of transformers/accelerate/peft
|
| 149 |
+
all_rewards = []
|
| 150 |
+
for i in range(reward_models.num_rewards):
|
| 151 |
+
all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
|
| 152 |
+
all_full_prompts = accelerator.gather_for_metrics(full_prompts)
|
| 153 |
+
all_full_responses = accelerator.gather_for_metrics(full_responses)
|
| 154 |
+
return all_rewards, all_full_prompts, all_full_responses
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
print("Evaluating........")
|
| 158 |
+
tokenizer.padding_side = "left"
|
| 159 |
+
## preference list
|
| 160 |
+
if reward_models.num_rewards == 3:
|
| 161 |
+
preferences = np.array([
|
| 162 |
+
[0.0, 0.0, 1.0],
|
| 163 |
+
[0.0, 1.0, 0.0],
|
| 164 |
+
[0.1, 0.1, 0.8],
|
| 165 |
+
[0.1, 0.8, 0.1],
|
| 166 |
+
[0.2, 0.2, 0.6],
|
| 167 |
+
[0.2, 0.4, 0.4],
|
| 168 |
+
[0.2, 0.6, 0.2],
|
| 169 |
+
[0.33, 0.33, 0.33],
|
| 170 |
+
[0.4, 0.4, 0.2],
|
| 171 |
+
[0.4, 0.2, 0.4],
|
| 172 |
+
[0.6, 0.2, 0.2],
|
| 173 |
+
[0.8, 0.1, 0.1],
|
| 174 |
+
[1.0, 0.0, 0.0],
|
| 175 |
+
])
|
| 176 |
+
elif reward_models.num_rewards == 2:
|
| 177 |
+
M = 10
|
| 178 |
+
preferences = np.zeros((M+1, 2))
|
| 179 |
+
preferences[:, 0] = np.arange(0,1+ 1/M,1 / M)
|
| 180 |
+
preferences[:, 1] = 1 - preferences[:, 0]
|
| 181 |
+
preferences = np.round(preferences, 1)
|
| 182 |
+
else:
|
| 183 |
+
raise NotImplementedError
|
| 184 |
+
|
| 185 |
+
for k in range(0, len(preferences)):
|
| 186 |
+
preference = preferences[k]
|
| 187 |
+
temp_save_path = './temp_models/{}'.format(script_args.wandb_name)
|
| 188 |
+
if process_id == 0:
|
| 189 |
+
if len(preference) == 3:
|
| 190 |
+
base_model_list = [base_model_name_1, base_model_name_2, base_model_name_3]
|
| 191 |
+
else:
|
| 192 |
+
base_model_list = [base_model_name_1, base_model_name_2]
|
| 193 |
+
merge_weights_with_preference(base_model_list, preference, temp_save_path)
|
| 194 |
+
|
| 195 |
+
accelerator.wait_for_everyone()
|
| 196 |
+
gc.collect()
|
| 197 |
+
torch.cuda.empty_cache()
|
| 198 |
+
print(k, preference)
|
| 199 |
+
all_rewards, all_full_prompts, all_full_responses = evaluate_model(temp_save_path, tokenizer, valid_dataset)
|
| 200 |
+
gc.collect()
|
| 201 |
+
torch.cuda.empty_cache()
|
| 202 |
+
|
| 203 |
+
if process_id == 0:
|
| 204 |
+
evaluation_result = {
|
| 205 |
+
'prompt': all_full_prompts,
|
| 206 |
+
'response': all_full_responses,
|
| 207 |
+
}
|
| 208 |
+
for i in range(reward_models.num_rewards):
|
| 209 |
+
evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
|
| 210 |
+
print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
|
| 211 |
+
|
| 212 |
+
dataframe = pd.DataFrame(evaluation_result)
|
| 213 |
+
if len(preference) == 2:
|
| 214 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}.csv'.format(preference[0], preference[1])))
|
| 215 |
+
else:
|
| 216 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}_{}.csv'.format(preference[0], preference[1], preference[2])))
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
ppo/morlhf-llama3.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from accelerate import Accelerator
|
| 6 |
+
import torch
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import AutoTokenizer, HfArgumentParser
|
| 9 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from utils import print_trainable_parameters, Instructions, Instructions_summary, \
|
| 13 |
+
load_llama3_tokenizer, build_dataset_llama3
|
| 14 |
+
from multi_reward_models import RewardModels
|
| 15 |
+
tqdm.pandas()
|
| 16 |
+
from peft import LoraConfig
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
|
| 19 |
+
# define paths for two datasets
|
| 20 |
+
hhrlhf_dataset_path = '../datasets/hh-rlhf'
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class ScriptArguments:
|
| 24 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 25 |
+
disable_wandb: Optional[str] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
|
| 26 |
+
save_directory: Optional[str] = field(default='./logs_morlhf/')
|
| 27 |
+
epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epoches"})
|
| 28 |
+
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
| 29 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
| 30 |
+
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size64"})
|
| 31 |
+
gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
|
| 32 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
|
| 33 |
+
target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"})
|
| 34 |
+
init_kl_coef: Optional[float] = field(default=0.2,metadata={"help": "0.05 Initial KL penalty coefficient (used for adaptive and linear control)"},)
|
| 35 |
+
max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
|
| 36 |
+
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
|
| 37 |
+
preference: Optional[str] = field(default=None)
|
| 38 |
+
wandb_name: Optional[str] = field(default='morlhf_llamma2_klreg0.2', metadata={"help": "Name for this experiment"})
|
| 39 |
+
base_model_name: Optional[str] = field(default='./merged_sft_summary', metadata={'help':"the path to the sft model; need to merge if using lora"})
|
| 40 |
+
reward_names:Optional[str] = field(default='harmless,helpful,humor')
|
| 41 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
|
| 42 |
+
|
| 43 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 44 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 45 |
+
exp_type = script_args.exp_type
|
| 46 |
+
|
| 47 |
+
tokenier_name = script_args.base_model_name
|
| 48 |
+
base_model_name = script_args.base_model_name
|
| 49 |
+
print('base model: ', base_model_name)
|
| 50 |
+
|
| 51 |
+
if script_args.disable_wandb: # if you don't need the wandb log
|
| 52 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
| 53 |
+
|
| 54 |
+
reward_names = [x.strip() for x in script_args.reward_names.split(',')]
|
| 55 |
+
num_rewards = len(reward_names)
|
| 56 |
+
print('number of rewards: {}'.format(num_rewards))
|
| 57 |
+
#######################################################################
|
| 58 |
+
assert num_rewards == 3
|
| 59 |
+
if script_args.preference is None:
|
| 60 |
+
# 均分preference
|
| 61 |
+
preference = [round(1 / num_rewards, 2) for _ in range(num_rewards)]
|
| 62 |
+
else:
|
| 63 |
+
preference = [float(x) for x in script_args.preference.split(",")]
|
| 64 |
+
print('preference: {}'.format(preference))
|
| 65 |
+
#######################################################################
|
| 66 |
+
reward_path_tokenizer_dict = {
|
| 67 |
+
'harmless': ['../reward_models/gpt2-large-harmless-reward_model'],
|
| 68 |
+
'helpful': ['../reward_models/gpt2-large-helpful-reward_model'],
|
| 69 |
+
'humor': ['../reward_models/humor-no-humor'],
|
| 70 |
+
}
|
| 71 |
+
reward_model_path_list = []
|
| 72 |
+
rm_tokenizer_path_list = []
|
| 73 |
+
for name in reward_names:
|
| 74 |
+
if name not in reward_path_tokenizer_dict.keys():
|
| 75 |
+
raise NotImplementedError
|
| 76 |
+
reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 77 |
+
rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 78 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
config = PPOConfig(
|
| 82 |
+
model_name=base_model_name,
|
| 83 |
+
learning_rate=script_args.learning_rate,
|
| 84 |
+
# log_with=script_args.log_with,
|
| 85 |
+
mini_batch_size=script_args.mini_batch_size,
|
| 86 |
+
batch_size=script_args.batch_size,
|
| 87 |
+
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
| 88 |
+
early_stopping=script_args.early_stopping,
|
| 89 |
+
target=script_args.target,
|
| 90 |
+
max_grad_norm=script_args.max_grad_norm,
|
| 91 |
+
optimize_cuda_cache=True,
|
| 92 |
+
init_kl_coef=script_args.init_kl_coef,
|
| 93 |
+
# tracker_project_name='morlhf',
|
| 94 |
+
# tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
|
| 95 |
+
remove_unused_columns=False # 防止‘reward_query’没了
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
accelerator = Accelerator()
|
| 99 |
+
process_id = Accelerator().local_process_index
|
| 100 |
+
gpu_id = process_id
|
| 101 |
+
print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
############## load reward models
|
| 105 |
+
reward_model = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id)
|
| 106 |
+
rm_tokenizer = AutoTokenizer.from_pretrained(rm_tokenizer_path_list[0])
|
| 107 |
+
|
| 108 |
+
def collator(data):
|
| 109 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
| 110 |
+
|
| 111 |
+
set_seed(8888)
|
| 112 |
+
current_device = Accelerator().local_process_index
|
| 113 |
+
print(current_device)
|
| 114 |
+
|
| 115 |
+
lora_config = LoraConfig(
|
| 116 |
+
r=64,
|
| 117 |
+
lora_alpha=128,
|
| 118 |
+
lora_dropout=0.05,
|
| 119 |
+
bias="none",
|
| 120 |
+
task_type="CAUSAL_LM",
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
tokenizer = load_llama3_tokenizer(tokenier_name)
|
| 125 |
+
if exp_type == 'assistant':
|
| 126 |
+
dataset = build_dataset_llama3(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 127 |
+
# instructions = Instructions()
|
| 128 |
+
else:
|
| 129 |
+
raise NotImplementedError
|
| 130 |
+
# dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 131 |
+
# instructions = Instructions_summary()
|
| 132 |
+
train_dataset = dataset.shuffle()
|
| 133 |
+
print(f"Size of the train set: {len(train_dataset)}")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if script_args.load_in_8bit:
|
| 137 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 138 |
+
base_model_name,
|
| 139 |
+
load_in_8bit=True,
|
| 140 |
+
peft_config=lora_config,
|
| 141 |
+
device_map=gpu_id,
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 145 |
+
base_model_name,
|
| 146 |
+
torch_dtype=torch.bfloat16,
|
| 147 |
+
peft_config=lora_config,
|
| 148 |
+
device_map=gpu_id,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
print_trainable_parameters(model)
|
| 152 |
+
model.pretrained_model.resize_token_embeddings(len(tokenizer))
|
| 153 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
| 154 |
+
|
| 155 |
+
ppo_trainer = PPOTrainer(
|
| 156 |
+
config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# 只进行一轮生成
|
| 160 |
+
terminators = [
|
| 161 |
+
tokenizer.eos_token_id,
|
| 162 |
+
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
| 163 |
+
]
|
| 164 |
+
generation_kwargs = {
|
| 165 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 166 |
+
"min_length": -1,
|
| 167 |
+
"top_k": 0.0,
|
| 168 |
+
"top_p": 1.0,
|
| 169 |
+
"do_sample": True,
|
| 170 |
+
"temperature": 0.7,
|
| 171 |
+
"pad_token_id": tokenizer.eos_token_id,
|
| 172 |
+
"begin_suppress_tokens": [tokenizer.eos_token_id] ,
|
| 173 |
+
'eos_token_id': terminators
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
print("Training........")
|
| 178 |
+
model.gradient_checkpointing_disable()
|
| 179 |
+
model.pretrained_model.config.use_cache = True
|
| 180 |
+
epochs = script_args.epochs
|
| 181 |
+
mean_scores = []
|
| 182 |
+
std_scores = []
|
| 183 |
+
save_data = {
|
| 184 |
+
'kl_mean': [],
|
| 185 |
+
'reward_mean': [],
|
| 186 |
+
'reward_std': [],
|
| 187 |
+
'text_sample':[],
|
| 188 |
+
'batch_time':[],
|
| 189 |
+
'total_time':[],
|
| 190 |
+
}
|
| 191 |
+
t_start = time.time()
|
| 192 |
+
for epoch in range(epochs):
|
| 193 |
+
pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
|
| 194 |
+
for i, batch in enumerate(ppo_trainer.dataloader):
|
| 195 |
+
# print(batch.keys())
|
| 196 |
+
t_epoch_start = time.time()
|
| 197 |
+
print('epoch {}, batch {}'.format(epoch, i))
|
| 198 |
+
query_tensors = batch["input_ids"]
|
| 199 |
+
|
| 200 |
+
model.gradient_checkpointing_disable()
|
| 201 |
+
model.pretrained_model.config.use_cache = True
|
| 202 |
+
|
| 203 |
+
with torch.no_grad():
|
| 204 |
+
response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
|
| 205 |
+
|
| 206 |
+
# 看ppo_trainer中的实现,pad token已经被去除干净了
|
| 207 |
+
full_responses_clean = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) # skip后,不用手动去除
|
| 208 |
+
|
| 209 |
+
# full_responses = tokenizer.batch_decode(response_tensors)
|
| 210 |
+
# full_responses_clean = []
|
| 211 |
+
# for _, response in enumerate(full_responses):
|
| 212 |
+
# response = response.strip('[PAD] ')
|
| 213 |
+
# response = response.strip('<unk>')
|
| 214 |
+
# temp_resp = response.strip('<s>').strip('</s>')
|
| 215 |
+
# temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 216 |
+
# temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 217 |
+
# temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 218 |
+
# temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 219 |
+
# temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 220 |
+
# temp_resp = temp_resp.split('###')[0].strip()
|
| 221 |
+
# full_responses_clean.append(temp_resp)
|
| 222 |
+
|
| 223 |
+
clean_texts = full_responses_clean
|
| 224 |
+
clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
|
| 225 |
+
|
| 226 |
+
# clean后的response map回去
|
| 227 |
+
lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
|
| 228 |
+
response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
|
| 229 |
+
batch['response'] = clean_texts
|
| 230 |
+
|
| 231 |
+
# Compute score
|
| 232 |
+
# texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
|
| 233 |
+
# # 从整个回复中分解prompt和response
|
| 234 |
+
# queries_responses = [
|
| 235 |
+
# (instructions.get_input(text), instructions.get_response(text))
|
| 236 |
+
# for text in texts_merge
|
| 237 |
+
# ]
|
| 238 |
+
# if hasattr(instructions, 'get_post'):
|
| 239 |
+
# rewards_list = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)
|
| 240 |
+
# else:
|
| 241 |
+
# rewards_list = reward_model.get_reward_model_scores(queries_responses)
|
| 242 |
+
|
| 243 |
+
qa_lsit = [(q, r) for q, r in zip(batch['reward_query'], batch['response'])]
|
| 244 |
+
rewards_list = reward_model.get_reward_model_scores(qa_lsit)
|
| 245 |
+
rewards = []
|
| 246 |
+
for j in range(len(qa_lsit)):
|
| 247 |
+
rewards.append(round(sum([preference[k] * rewards_list[k][j] for k in range(num_rewards)]), 2))
|
| 248 |
+
rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
|
| 249 |
+
print("iter {}, batch {}, mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
|
| 250 |
+
|
| 251 |
+
model.gradient_checkpointing_enable()
|
| 252 |
+
model.pretrained_model.config.use_cache = False
|
| 253 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
|
| 254 |
+
policy_kl = [stats["objective/kl"]]
|
| 255 |
+
ppo_trainer.log_stats(stats, batch, rewards)
|
| 256 |
+
|
| 257 |
+
all_rewards = accelerator.gather_for_metrics(rewards)
|
| 258 |
+
all_policy_kl = accelerator.gather_for_metrics(policy_kl)
|
| 259 |
+
if process_id == 0:
|
| 260 |
+
mean_scores.append(torch.mean(torch.tensor(all_rewards)).item())
|
| 261 |
+
std_scores.append(torch.std(torch.tensor(all_rewards)).item())
|
| 262 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
|
| 263 |
+
plt.plot(mean_scores)
|
| 264 |
+
plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores)- np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
|
| 265 |
+
plt.savefig(save_path)
|
| 266 |
+
t_epoch_end = time.time()
|
| 267 |
+
save_data['batch_time'].append(t_epoch_end - t_epoch_start)
|
| 268 |
+
save_data['total_time'].append(t_epoch_end - t_start)
|
| 269 |
+
save_data['kl_mean'].append(np.mean(all_policy_kl))
|
| 270 |
+
save_data['reward_mean'] = mean_scores
|
| 271 |
+
save_data['reward_std'] = std_scores
|
| 272 |
+
save_data['text_sample'].append(batch['query'][0] + batch['response'][0])
|
| 273 |
+
# save_data['text_sample'].append(texts_merge[0])
|
| 274 |
+
dataframe = pd.DataFrame(save_data)
|
| 275 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
|
| 276 |
+
print("iter {}, batch {}: log finish".format(epoch, i))
|
| 277 |
+
|
| 278 |
+
# wait for the main process
|
| 279 |
+
accelerator.wait_for_everyone()
|
| 280 |
+
pbar.update(1)
|
| 281 |
+
|
| 282 |
+
# save model
|
| 283 |
+
if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
|
| 284 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}-batch_{}'.format(epoch, i))
|
| 285 |
+
ppo_trainer.save_pretrained(save_path)
|
| 286 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 287 |
+
|
| 288 |
+
# save model
|
| 289 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 290 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}'.format(epoch))
|
| 291 |
+
ppo_trainer.save_pretrained(save_path)
|
| 292 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 293 |
+
|
ppo/morlhf.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from accelerate import Accelerator
|
| 6 |
+
import torch
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import AutoTokenizer, HfArgumentParser
|
| 9 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
|
| 13 |
+
build_dataset, build_dataset_summary
|
| 14 |
+
from multi_reward_models import RewardModels
|
| 15 |
+
tqdm.pandas()
|
| 16 |
+
from peft import LoraConfig
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
|
| 19 |
+
# define paths for two datasets
|
| 20 |
+
hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
|
| 21 |
+
summary_dataset_path = 'openai/summarize_from_feedback'
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ScriptArguments:
|
| 25 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 26 |
+
disable_wandb: Optional[str] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
|
| 27 |
+
save_directory: Optional[str] = field(default='./logs_morlhf/')
|
| 28 |
+
epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epoches"})
|
| 29 |
+
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
| 30 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
| 31 |
+
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size64"})
|
| 32 |
+
gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
|
| 33 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
|
| 34 |
+
target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"})
|
| 35 |
+
init_kl_coef: Optional[float] = field(default=0.2,metadata={"help": "0.05 Initial KL penalty coefficient (used for adaptive and linear control)"},)
|
| 36 |
+
max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
|
| 37 |
+
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "loading model in 8 bit or bfloat16"})
|
| 38 |
+
preference: Optional[float] = field(default=0.5, metadata={"help": "the weight for reward 1"})
|
| 39 |
+
wandb_name: Optional[str] = field(default='morlhf_llamma2_klreg0.2', metadata={"help": "Name for this experiment"})
|
| 40 |
+
base_model_name: Optional[str] = field(default='./merged_sft_summary', metadata={'help':"the path to the sft model; need to merge if using lora"})
|
| 41 |
+
reward_names:Optional[str] = field(default='harmless,helpful,humor')
|
| 42 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
|
| 43 |
+
|
| 44 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 45 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 46 |
+
exp_type = script_args.exp_type
|
| 47 |
+
preference = [round(script_args.preference, 1), round(1 - script_args.preference, 1)]
|
| 48 |
+
script_args.wandb_name = script_args.wandb_name + '_pref{}_{}'.format(preference[0], preference[1])
|
| 49 |
+
|
| 50 |
+
tokenier_name = script_args.base_model_name
|
| 51 |
+
base_model_name = script_args.base_model_name
|
| 52 |
+
print('base model: ', base_model_name)
|
| 53 |
+
|
| 54 |
+
if script_args.disable_wandb: # if you don't need the wandb log
|
| 55 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
| 56 |
+
|
| 57 |
+
reward_names = [x.strip() for x in script_args.reward_names.split(',')]
|
| 58 |
+
num_rewards = len(reward_names)
|
| 59 |
+
print('number of rewards: {}'.format(num_rewards))
|
| 60 |
+
#######################################################################
|
| 61 |
+
if num_rewards == 3:
|
| 62 |
+
preference = [round(1 / num_rewards, 2) for _ in range(num_rewards)]
|
| 63 |
+
print('preference: {}'.format(preference))
|
| 64 |
+
#######################################################################
|
| 65 |
+
reward_path_tokenizer_dict = {
|
| 66 |
+
'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
|
| 67 |
+
'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
|
| 68 |
+
'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
|
| 69 |
+
'summary': ['Tristan/gpt2_reward_summarization'],
|
| 70 |
+
'faithful':['CogComp/bart-faithful-summary-detector'],
|
| 71 |
+
'humor': ['mohameddhiab/humor-no-humor'],
|
| 72 |
+
}
|
| 73 |
+
reward_model_path_list = []
|
| 74 |
+
rm_tokenizer_path_list = []
|
| 75 |
+
for name in reward_names:
|
| 76 |
+
if name not in reward_path_tokenizer_dict.keys():
|
| 77 |
+
raise NotImplementedError
|
| 78 |
+
reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 79 |
+
rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 80 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
config = PPOConfig(
|
| 84 |
+
model_name=base_model_name,
|
| 85 |
+
learning_rate=script_args.learning_rate,
|
| 86 |
+
log_with=script_args.log_with,
|
| 87 |
+
mini_batch_size=script_args.mini_batch_size,
|
| 88 |
+
batch_size=script_args.batch_size,
|
| 89 |
+
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
| 90 |
+
early_stopping=script_args.early_stopping,
|
| 91 |
+
target=script_args.target,
|
| 92 |
+
max_grad_norm=script_args.max_grad_norm,
|
| 93 |
+
optimize_cuda_cache=True,
|
| 94 |
+
init_kl_coef=script_args.init_kl_coef,
|
| 95 |
+
tracker_project_name='morlhf',
|
| 96 |
+
tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
accelerator = Accelerator()
|
| 100 |
+
process_id = Accelerator().local_process_index
|
| 101 |
+
gpu_id = process_id
|
| 102 |
+
print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
############## load reward models
|
| 106 |
+
reward_model = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id)
|
| 107 |
+
rm_tokenizer = AutoTokenizer.from_pretrained(rm_tokenizer_path_list[0])
|
| 108 |
+
|
| 109 |
+
def collator(data):
|
| 110 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
| 111 |
+
|
| 112 |
+
set_seed(8888)
|
| 113 |
+
current_device = Accelerator().local_process_index
|
| 114 |
+
print(current_device)
|
| 115 |
+
|
| 116 |
+
lora_config = LoraConfig(
|
| 117 |
+
r=64,
|
| 118 |
+
lora_alpha=128,
|
| 119 |
+
lora_dropout=0.05,
|
| 120 |
+
bias="none",
|
| 121 |
+
task_type="CAUSAL_LM",
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 126 |
+
if exp_type == 'assistant':
|
| 127 |
+
dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 128 |
+
instructions = Instructions()
|
| 129 |
+
else:
|
| 130 |
+
dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 131 |
+
instructions = Instructions_summary()
|
| 132 |
+
train_dataset = dataset.shuffle()
|
| 133 |
+
print(f"Size of the train set: {len(train_dataset)}")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if script_args.load_in_8bit:
|
| 137 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 138 |
+
base_model_name,
|
| 139 |
+
load_in_8bit=True,
|
| 140 |
+
peft_config=lora_config,
|
| 141 |
+
device_map=gpu_id,
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 145 |
+
base_model_name,
|
| 146 |
+
torch_dtype=torch.bfloat16,
|
| 147 |
+
peft_config=lora_config,
|
| 148 |
+
device_map=gpu_id,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
print_trainable_parameters(model)
|
| 152 |
+
model.pretrained_model.resize_token_embeddings(len(tokenizer))
|
| 153 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
| 154 |
+
|
| 155 |
+
ppo_trainer = PPOTrainer(
|
| 156 |
+
config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
generation_kwargs = {
|
| 160 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 161 |
+
"min_length": -1,
|
| 162 |
+
"top_k": 0.0,
|
| 163 |
+
"top_p": 1.0,
|
| 164 |
+
"do_sample": True,
|
| 165 |
+
"temperature": 0.7,
|
| 166 |
+
"pad_token_id": tokenizer.eos_token_id,
|
| 167 |
+
"begin_suppress_tokens": [tokenizer.eos_token_id] ,
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
print("Training........")
|
| 172 |
+
model.gradient_checkpointing_disable()
|
| 173 |
+
model.pretrained_model.config.use_cache = True
|
| 174 |
+
epochs = script_args.epochs
|
| 175 |
+
mean_scores = []
|
| 176 |
+
std_scores = []
|
| 177 |
+
save_data = {
|
| 178 |
+
'kl_mean': [],
|
| 179 |
+
'reward_mean': [],
|
| 180 |
+
'reward_std': [],
|
| 181 |
+
'text_sample':[],
|
| 182 |
+
'batch_time':[],
|
| 183 |
+
'total_time':[],
|
| 184 |
+
}
|
| 185 |
+
t_start = time.time()
|
| 186 |
+
for epoch in range(epochs):
|
| 187 |
+
pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
|
| 188 |
+
for i, batch in enumerate(ppo_trainer.dataloader):
|
| 189 |
+
t_epoch_start = time.time()
|
| 190 |
+
print('epoch {}, batch {}'.format(epoch, i))
|
| 191 |
+
query_tensors = batch["input_ids"]
|
| 192 |
+
|
| 193 |
+
model.gradient_checkpointing_disable()
|
| 194 |
+
model.pretrained_model.config.use_cache = True
|
| 195 |
+
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
|
| 198 |
+
|
| 199 |
+
full_responses = tokenizer.batch_decode(response_tensors)
|
| 200 |
+
full_responses_clean = []
|
| 201 |
+
for _, response in enumerate(full_responses):
|
| 202 |
+
response = response.strip('[PAD] ')
|
| 203 |
+
response = response.strip('<unk>')
|
| 204 |
+
temp_resp = response.strip('<s>').strip('</s>')
|
| 205 |
+
temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 206 |
+
temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 207 |
+
temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 208 |
+
temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 209 |
+
temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 210 |
+
temp_resp = temp_resp.split('###')[0].strip()
|
| 211 |
+
full_responses_clean.append(temp_resp)
|
| 212 |
+
|
| 213 |
+
clean_texts = full_responses_clean
|
| 214 |
+
clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
|
| 215 |
+
|
| 216 |
+
lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
|
| 217 |
+
response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
|
| 218 |
+
batch['response'] = clean_texts
|
| 219 |
+
|
| 220 |
+
# Compute score
|
| 221 |
+
texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
|
| 222 |
+
queries_responses = [
|
| 223 |
+
(instructions.get_input(text), instructions.get_response(text))
|
| 224 |
+
for text in texts_merge
|
| 225 |
+
]
|
| 226 |
+
if hasattr(instructions, 'get_post'):
|
| 227 |
+
rewards_list = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)
|
| 228 |
+
else:
|
| 229 |
+
rewards_list = reward_model.get_reward_model_scores(queries_responses)
|
| 230 |
+
rewards = []
|
| 231 |
+
for j in range(len(queries_responses)):
|
| 232 |
+
rewards.append(round(sum([preference[k] * rewards_list[k][j] for k in range(num_rewards)]), 2))
|
| 233 |
+
rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
|
| 234 |
+
print("iter {}, batch {}, mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
|
| 235 |
+
|
| 236 |
+
model.gradient_checkpointing_enable()
|
| 237 |
+
model.pretrained_model.config.use_cache = False
|
| 238 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
|
| 239 |
+
policy_kl = [stats["objective/kl"]]
|
| 240 |
+
ppo_trainer.log_stats(stats, batch, rewards)
|
| 241 |
+
|
| 242 |
+
all_rewards = accelerator.gather_for_metrics(rewards)
|
| 243 |
+
all_policy_kl = accelerator.gather_for_metrics(policy_kl)
|
| 244 |
+
if process_id == 0:
|
| 245 |
+
mean_scores.append(torch.mean(torch.tensor(all_rewards)).item())
|
| 246 |
+
std_scores.append(torch.std(torch.tensor(all_rewards)).item())
|
| 247 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
|
| 248 |
+
plt.plot(mean_scores)
|
| 249 |
+
plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores)- np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
|
| 250 |
+
plt.savefig(save_path)
|
| 251 |
+
t_epoch_end = time.time()
|
| 252 |
+
save_data['batch_time'].append(t_epoch_end - t_epoch_start)
|
| 253 |
+
save_data['total_time'].append(t_epoch_end - t_start)
|
| 254 |
+
save_data['kl_mean'].append(np.mean(all_policy_kl))
|
| 255 |
+
save_data['reward_mean'] = mean_scores
|
| 256 |
+
save_data['reward_std'] = std_scores
|
| 257 |
+
save_data['text_sample'].append(texts_merge[0])
|
| 258 |
+
dataframe = pd.DataFrame(save_data)
|
| 259 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
|
| 260 |
+
print("iter {}, batch {}: log finish".format(epoch, i))
|
| 261 |
+
|
| 262 |
+
# wait for the main process
|
| 263 |
+
accelerator.wait_for_everyone()
|
| 264 |
+
pbar.update(1)
|
| 265 |
+
|
| 266 |
+
# save model
|
| 267 |
+
if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
|
| 268 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}-batch_{}'.format(epoch, i))
|
| 269 |
+
ppo_trainer.save_pretrained(save_path)
|
| 270 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 271 |
+
|
| 272 |
+
# save model
|
| 273 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 274 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'batch_{}'.format(i))
|
| 275 |
+
ppo_trainer.save_pretrained(save_path)
|
| 276 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 277 |
+
|
ppo/multi_reward_models.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from utils import load_reward_model, get_rewards
|
| 6 |
+
|
| 7 |
+
class RewardModels():
|
| 8 |
+
def __init__(self, reward_model_path_list, rm_tokenizer_path_list, gpu_id_list, reward_stats_path=None):
|
| 9 |
+
assert len(reward_model_path_list) == len(rm_tokenizer_path_list)
|
| 10 |
+
self.reward_model_path_list = reward_model_path_list
|
| 11 |
+
self.rm_tokenizer_path_list = rm_tokenizer_path_list
|
| 12 |
+
self.num_rewards = len(reward_model_path_list)
|
| 13 |
+
self.reward_stats = np.load(reward_stats_path) if reward_stats_path is not None else None
|
| 14 |
+
self.reward_models = []
|
| 15 |
+
self.rm_tokenizers = []
|
| 16 |
+
if type(gpu_id_list) != list:
|
| 17 |
+
gpu_id_list = [gpu_id_list, gpu_id_list, gpu_id_list]
|
| 18 |
+
|
| 19 |
+
print('Loading reward models .....')
|
| 20 |
+
for i in range(self.num_rewards):
|
| 21 |
+
self.reward_models.append(load_reward_model(self.reward_model_path_list[i], gpu_id_list[i]))
|
| 22 |
+
self.rm_tokenizers.append(AutoTokenizer.from_pretrained(self.rm_tokenizer_path_list[i]))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_reward_model_scores(self, queries_responses, summary_fun=None):
|
| 26 |
+
texts_for_rewards = []
|
| 27 |
+
for i in range(self.num_rewards):
|
| 28 |
+
if i >= 1 and self.rm_tokenizer_path_list[i] == self.rm_tokenizer_path_list[i-1]:
|
| 29 |
+
texts_for_rewards.append(texts_for_rewards[-1])
|
| 30 |
+
elif 'faithful' in self.reward_model_path_list[i]:
|
| 31 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 32 |
+
temp_encoded_texts = [self.rm_tokenizers[i](text=r, text_pair=summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 33 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 34 |
+
elif 'summary' in self.reward_model_path_list[i] or 'summarization' in self.reward_model_path_list[i]: # reverse prompt and response
|
| 35 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 36 |
+
temp_encoded_texts = [self.rm_tokenizers[i](r + " " + self.rm_tokenizers[i].bos_token + " " + summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 37 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 38 |
+
elif 'humor' in self.reward_model_path_list[i]: # use only response
|
| 39 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 40 |
+
temp_encoded_texts = [self.rm_tokenizers[i](r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 41 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 42 |
+
else:
|
| 43 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 44 |
+
temp_encoded_texts = [self.rm_tokenizers[i](q, r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 45 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 46 |
+
|
| 47 |
+
# normalize reward
|
| 48 |
+
rewards = []
|
| 49 |
+
for i in range(self.num_rewards):
|
| 50 |
+
if self.reward_stats is not None:
|
| 51 |
+
if type(self.reward_stats) == list or len(self.reward_stats) == 2 * self.num_rewards:
|
| 52 |
+
reward_mean_std = (self.reward_stats[2*i], self.reward_stats[2*i+1])
|
| 53 |
+
else:
|
| 54 |
+
reward_mean_std = self.reward_stats[i]
|
| 55 |
+
else:
|
| 56 |
+
reward_mean_std = None
|
| 57 |
+
|
| 58 |
+
if 'humor' in self.reward_model_path_list[i] or 'faithful' in self.reward_model_path_list[i]:
|
| 59 |
+
temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std, sub_position=1)
|
| 60 |
+
else:
|
| 61 |
+
temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std)
|
| 62 |
+
rewards.append(temp_reward)
|
| 63 |
+
return rewards
|
| 64 |
+
|
ppo/ppo.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from accelerate import Accelerator
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import HfArgumentParser
|
| 8 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
|
| 12 |
+
build_dataset, build_dataset_summary
|
| 13 |
+
from multi_reward_models import RewardModels
|
| 14 |
+
tqdm.pandas()
|
| 15 |
+
from peft import LoraConfig
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
from transformers import LlamaForCausalLM
|
| 18 |
+
os.environ["WANDB_MODE"] = "offline"
|
| 19 |
+
# define paths for two datasets
|
| 20 |
+
hhrlhf_dataset_path = '../datasets/hh-rlhf'
|
| 21 |
+
# summary_dataset_path = 'openai/summarize_from_feedback'
|
| 22 |
+
|
| 23 |
+
class BlockWrapper(torch.nn.Module):
|
| 24 |
+
def __init__(self, block, vec=None):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.multiplier = 1.0
|
| 27 |
+
self.block = block
|
| 28 |
+
if vec is not None:
|
| 29 |
+
self.vec = torch.nn.Parameter(vec)
|
| 30 |
+
else:
|
| 31 |
+
# Zero Init
|
| 32 |
+
self.vec = torch.nn.Parameter(torch.zeros(4096, dtype=torch.bfloat16))
|
| 33 |
+
|
| 34 |
+
def forward(self, *args, **kwargs):
|
| 35 |
+
output = self.block(*args, **kwargs)
|
| 36 |
+
output = (output[0] + (self.multiplier * self.vec),) + output[1:]
|
| 37 |
+
return output
|
| 38 |
+
|
| 39 |
+
def set_multiplier(self, multiplier):
|
| 40 |
+
self.multiplier = multiplier
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class ScriptArguments:
|
| 44 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 45 |
+
disable_wandb: Optional[str] = field(default=True, metadata={'help': 'Whether to disable wandb or not.'})
|
| 46 |
+
save_directory: Optional[str] = field(default='./ckpt/')
|
| 47 |
+
epochs: Optional[int] = field(default=3, metadata={'help': "Number of training epoches"})
|
| 48 |
+
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
| 49 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
| 50 |
+
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
|
| 51 |
+
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
|
| 52 |
+
gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
|
| 53 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
|
| 54 |
+
target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"}) # maybe 12
|
| 55 |
+
init_kl_coef: Optional[float] = field(default=0.2,metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
|
| 56 |
+
max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
|
| 57 |
+
wandb_name: Optional[str] = field(default='new', metadata={"help": "Name for this experiment"})
|
| 58 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type: 'assistant" or 'summary'})
|
| 59 |
+
base_model_name: Optional[str] = field(default=None, metadata={'help':"the path to the sft model; need to merge if using lora"})
|
| 60 |
+
reward_name: Optional[str] = field(default='helpful')
|
| 61 |
+
layers: Optional[str] = field(default='15', metadata={"help": "the layers the steering vector extracted from: 12,13,14,15"})
|
| 62 |
+
|
| 63 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 64 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 65 |
+
exp_type = script_args.exp_type
|
| 66 |
+
# Remember to use a merged sft model if using lora
|
| 67 |
+
base_model_name = script_args.base_model_name
|
| 68 |
+
tokenier_name = script_args.base_model_name
|
| 69 |
+
print('base model: ', base_model_name)
|
| 70 |
+
|
| 71 |
+
script_args.layers = [int(layer) for layer in script_args.layers.split(',')]
|
| 72 |
+
|
| 73 |
+
if script_args.disable_wandb: # if you don't need the wandb log
|
| 74 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
| 75 |
+
|
| 76 |
+
reward_name = script_args.reward_name
|
| 77 |
+
# if reward_name == 'summary':
|
| 78 |
+
# reward_peft_path = 'Tristan/gpt2_reward_summarization'
|
| 79 |
+
# elif reward_name == 'faithful':
|
| 80 |
+
# reward_peft_path = 'CogComp/bart-faithful-summary-detector'
|
| 81 |
+
if reward_name == 'helpful':
|
| 82 |
+
reward_peft_path = '../reward_models/gpt2-large-helpful-reward_model'
|
| 83 |
+
elif reward_name == 'harmless':
|
| 84 |
+
reward_peft_path = '../reward_models/gpt2-large-harmless-reward_model'
|
| 85 |
+
# elif reward_name == 'deberta':
|
| 86 |
+
# reward_peft_path = 'OpenAssistant/reward-model-deberta-v3-large-v2'
|
| 87 |
+
elif reward_name == 'humor':
|
| 88 |
+
reward_peft_path = '../reward_models/humor-no-humor'
|
| 89 |
+
else:
|
| 90 |
+
raise NotImplementedError
|
| 91 |
+
rm_tokenizer_path = reward_peft_path
|
| 92 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
config = PPOConfig(
|
| 96 |
+
model_name=base_model_name,
|
| 97 |
+
learning_rate=script_args.learning_rate,
|
| 98 |
+
# log_with=script_args.log_with,
|
| 99 |
+
mini_batch_size=script_args.mini_batch_size,
|
| 100 |
+
batch_size=script_args.batch_size,
|
| 101 |
+
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
| 102 |
+
early_stopping=script_args.early_stopping,
|
| 103 |
+
target=script_args.target,
|
| 104 |
+
max_grad_norm=script_args.max_grad_norm,
|
| 105 |
+
optimize_cuda_cache=True,
|
| 106 |
+
init_kl_coef=script_args.init_kl_coef
|
| 107 |
+
# tracker_project_name='ppo',
|
| 108 |
+
# tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
accelerator = Accelerator()
|
| 112 |
+
process_id = Accelerator().local_process_index
|
| 113 |
+
gpu_id = process_id
|
| 114 |
+
print('process: {}'.format(process_id))
|
| 115 |
+
reward_model = RewardModels([reward_peft_path], [rm_tokenizer_path], gpu_id)
|
| 116 |
+
rm_tokenizer = reward_model.rm_tokenizers[0]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# set seed before initializing value head for deterministic eval
|
| 120 |
+
set_seed(8888)
|
| 121 |
+
current_device = Accelerator().local_process_index
|
| 122 |
+
print(current_device)
|
| 123 |
+
|
| 124 |
+
# lora_config = LoraConfig(
|
| 125 |
+
# r=64,
|
| 126 |
+
# lora_alpha=128,
|
| 127 |
+
# lora_dropout=0.05,
|
| 128 |
+
# bias="none",
|
| 129 |
+
# task_type="CAUSAL_LM",
|
| 130 |
+
# )
|
| 131 |
+
|
| 132 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 133 |
+
if exp_type == 'assistant':
|
| 134 |
+
dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 135 |
+
instructions = Instructions()
|
| 136 |
+
else:
|
| 137 |
+
raise NotImplementedError
|
| 138 |
+
# dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 139 |
+
# instructions = Instructions_summary()
|
| 140 |
+
train_dataset = dataset.shuffle()
|
| 141 |
+
print(f"Size of the train set: {len(train_dataset)}.")
|
| 142 |
+
|
| 143 |
+
if script_args.load_in_8bit:
|
| 144 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 145 |
+
base_model_name,
|
| 146 |
+
load_in_8bit=True,
|
| 147 |
+
# peft_config=lora_config,
|
| 148 |
+
device_map=gpu_id,
|
| 149 |
+
)
|
| 150 |
+
else:
|
| 151 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 152 |
+
base_model_name,
|
| 153 |
+
torch_dtype=torch.bfloat16,
|
| 154 |
+
# peft_config=lora_config,
|
| 155 |
+
device_map=gpu_id,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
###################
|
| 159 |
+
for layer in script_args.layers:
|
| 160 |
+
model.pretrained_model.model.layers[layer] = BlockWrapper(model.pretrained_model.model.layers[layer])
|
| 161 |
+
model.pretrained_model.config.use_cache = False
|
| 162 |
+
for name, param in model.pretrained_model.named_parameters():
|
| 163 |
+
param.requires_grad = False
|
| 164 |
+
for layer in script_args.layers:
|
| 165 |
+
model.pretrained_model.model.layers[layer].vec.requires_grad = True
|
| 166 |
+
###################
|
| 167 |
+
|
| 168 |
+
print_trainable_parameters(model)
|
| 169 |
+
print('trainable params names:')
|
| 170 |
+
for name, param in model.named_parameters():
|
| 171 |
+
if param.requires_grad:
|
| 172 |
+
print(name)
|
| 173 |
+
|
| 174 |
+
model.pretrained_model.resize_token_embeddings(len(tokenizer))
|
| 175 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
| 176 |
+
def collator(data):
|
| 177 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
| 178 |
+
|
| 179 |
+
ppo_trainer = PPOTrainer(
|
| 180 |
+
config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
generation_kwargs = {
|
| 184 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 185 |
+
'min_length': -1,
|
| 186 |
+
"top_k": 0.0,
|
| 187 |
+
"top_p": 1.0,
|
| 188 |
+
"do_sample": True,
|
| 189 |
+
"temperature": 0.7,
|
| 190 |
+
"pad_token_id": tokenizer.eos_token_id,
|
| 191 |
+
"begin_suppress_tokens": [tokenizer.eos_token_id],
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
print("Training........")
|
| 195 |
+
model.gradient_checkpointing_disable()
|
| 196 |
+
model.pretrained_model.config.use_cache = True
|
| 197 |
+
|
| 198 |
+
epochs = script_args.epochs
|
| 199 |
+
mean_scores = []
|
| 200 |
+
std_scores = []
|
| 201 |
+
save_data = {
|
| 202 |
+
'kl_mean': [],
|
| 203 |
+
'kl_std': [],
|
| 204 |
+
'reward_mean': [],
|
| 205 |
+
'reward_std': [],
|
| 206 |
+
'text_sample':[],
|
| 207 |
+
}
|
| 208 |
+
for epoch in range(epochs):
|
| 209 |
+
pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
|
| 210 |
+
for i, batch in enumerate(ppo_trainer.dataloader):
|
| 211 |
+
print('epoch {}, batch {}'.format(epoch, i))
|
| 212 |
+
query_tensors = batch["input_ids"]
|
| 213 |
+
|
| 214 |
+
model.gradient_checkpointing_disable()
|
| 215 |
+
model.pretrained_model.config.use_cache = True
|
| 216 |
+
|
| 217 |
+
with torch.no_grad():
|
| 218 |
+
response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
|
| 219 |
+
|
| 220 |
+
full_responses = tokenizer.batch_decode(response_tensors)
|
| 221 |
+
full_responses_clean = []
|
| 222 |
+
for _, response in enumerate(full_responses):
|
| 223 |
+
response = response.strip('[PAD] ')
|
| 224 |
+
response = response.strip('<unk>')
|
| 225 |
+
temp_resp = response.strip('<s>').strip('</s>')
|
| 226 |
+
temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 227 |
+
temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 228 |
+
temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 229 |
+
temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 230 |
+
temp_resp = temp_resp.split('###')[0].strip()
|
| 231 |
+
temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 232 |
+
full_responses_clean.append(temp_resp)
|
| 233 |
+
|
| 234 |
+
clean_texts = full_responses_clean
|
| 235 |
+
clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
|
| 236 |
+
|
| 237 |
+
lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
|
| 238 |
+
print(lengths)
|
| 239 |
+
response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
|
| 240 |
+
batch['response'] = clean_texts
|
| 241 |
+
|
| 242 |
+
# Compute score
|
| 243 |
+
texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
|
| 244 |
+
queries_responses = [
|
| 245 |
+
(instructions.get_input(text), instructions.get_response(text))
|
| 246 |
+
for text in texts_merge
|
| 247 |
+
]
|
| 248 |
+
if hasattr(instructions, 'get_post'):
|
| 249 |
+
rewards = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)[0]
|
| 250 |
+
else:
|
| 251 |
+
rewards = reward_model.get_reward_model_scores(queries_responses)[0]
|
| 252 |
+
rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
|
| 253 |
+
print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
|
| 254 |
+
|
| 255 |
+
model.gradient_checkpointing_enable()
|
| 256 |
+
model.pretrained_model.config.use_cache = False
|
| 257 |
+
ppo_trainer.config.batch_size = len(query_tensors)
|
| 258 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
|
| 259 |
+
ppo_trainer.log_stats(stats, batch, rewards)
|
| 260 |
+
policy_kl = [stats["objective/kl"]]
|
| 261 |
+
|
| 262 |
+
all_rewards = accelerator.gather_for_metrics(rewards)
|
| 263 |
+
all_policy_kl = accelerator.gather_for_metrics(policy_kl)
|
| 264 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 265 |
+
mean_scores.append(torch.mean(torch.tensor(rewards)).item())
|
| 266 |
+
std_scores.append(torch.std(torch.tensor(rewards)).item())
|
| 267 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
|
| 268 |
+
plt.plot(mean_scores)
|
| 269 |
+
plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores) - np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
|
| 270 |
+
plt.savefig(save_path)
|
| 271 |
+
|
| 272 |
+
save_data['kl_mean'].append(np.mean(all_policy_kl))
|
| 273 |
+
save_data['kl_std'].append(np.std(all_policy_kl))
|
| 274 |
+
save_data['reward_mean'] = mean_scores
|
| 275 |
+
save_data['reward_std'] = std_scores
|
| 276 |
+
save_data['text_sample'].append(texts_merge[0])
|
| 277 |
+
dataframe = pd.DataFrame(save_data)
|
| 278 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
|
| 279 |
+
print("iter {}, batch {}: log finish".format(epoch, i))
|
| 280 |
+
|
| 281 |
+
# wait for the main process
|
| 282 |
+
accelerator.wait_for_everyone()
|
| 283 |
+
pbar.update(1)
|
| 284 |
+
|
| 285 |
+
# save model
|
| 286 |
+
if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
|
| 287 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 288 |
+
# ppo_trainer.save_pretrained(save_path)
|
| 289 |
+
|
| 290 |
+
for layer in script_args.layers:
|
| 291 |
+
# steer_vec = ppo_trainer.model.pretrained_model.model.layers[layer].vec.detach().cpu()
|
| 292 |
+
steer_vec = model.pretrained_model.model.layers[layer].vec.detach().cpu()
|
| 293 |
+
print(f'Steer vec of layer {layer} at step {i}: ', steer_vec[:10], steer_vec.dtype)
|
| 294 |
+
torch.save(
|
| 295 |
+
steer_vec,
|
| 296 |
+
f"{save_path}/vec_layer{layer}-step{i}.pt",
|
| 297 |
+
)
|
| 298 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 299 |
+
|
| 300 |
+
# save model
|
| 301 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 302 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 303 |
+
# ppo_trainer.save_pretrained(save_path)
|
| 304 |
+
for layer in script_args.layers:
|
| 305 |
+
# steer_vec = ppo_trainer.model.pretrained_model.model.layers[layer].vec.detach().cpu()
|
| 306 |
+
steer_vec = model.pretrained_model.model.layers[layer].vec.detach().cpu()
|
| 307 |
+
print(f'Steer vec of layer {layer} at step {i}: ', steer_vec[:10], steer_vec.dtype)
|
| 308 |
+
torch.save(
|
| 309 |
+
steer_vec,
|
| 310 |
+
f"{save_path}/vec_layer{layer}-step{i}.pt",
|
| 311 |
+
)
|
| 312 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 313 |
+
|
ppo/ppo_reft.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from accelerate import Accelerator
|
| 5 |
+
import torch
|
| 6 |
+
import math
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import HfArgumentParser
|
| 9 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
|
| 13 |
+
build_dataset, build_dataset_summary
|
| 14 |
+
from multi_reward_models import RewardModels
|
| 15 |
+
tqdm.pandas()
|
| 16 |
+
from peft import LoraConfig
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
os.environ["WANDB_MODE"] = "offline"
|
| 19 |
+
# define paths for two datasets
|
| 20 |
+
hhrlhf_dataset_path = '../datasets/hh-rlhf'
|
| 21 |
+
# summary_dataset_path = 'openai/summarize_from_feedback'
|
| 22 |
+
|
| 23 |
+
from transformers.activations import ACT2FN
|
| 24 |
+
class BlockWrapper(torch.nn.Module):
|
| 25 |
+
def __init__(self, block, rank=64):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.multiplier = 1.0
|
| 28 |
+
self.block = block
|
| 29 |
+
self.learned_source = torch.nn.Linear(4096, rank).to(torch.bfloat16)
|
| 30 |
+
torch.nn.init.kaiming_uniform_(self.learned_source.weight, a=math.sqrt(5))
|
| 31 |
+
self.proj_layer = torch.nn.Linear(rank, 4096, bias=False).to(torch.bfloat16)
|
| 32 |
+
torch.nn.init.zeros_(self.proj_layer.weight)
|
| 33 |
+
# self.act_fn = ACT2FN["linear"] # TODO: revise
|
| 34 |
+
|
| 35 |
+
def forward(self, *args, **kwargs):
|
| 36 |
+
output = self.block(*args, **kwargs)
|
| 37 |
+
output = (output[0] + (self.multiplier * self.proj_layer(self.learned_source(output[0]))),) + output[1:]
|
| 38 |
+
return output
|
| 39 |
+
|
| 40 |
+
def set_multiplier(self, multiplier):
|
| 41 |
+
self.multiplier = multiplier
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class ScriptArguments:
|
| 45 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 46 |
+
disable_wandb: Optional[str] = field(default=True, metadata={'help': 'Whether to disable wandb or not.'})
|
| 47 |
+
save_directory: Optional[str] = field(default='./ckpt/')
|
| 48 |
+
epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epoches"})
|
| 49 |
+
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
| 50 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
| 51 |
+
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
|
| 52 |
+
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
|
| 53 |
+
gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
|
| 54 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
|
| 55 |
+
target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"}) # maybe 12
|
| 56 |
+
init_kl_coef: Optional[float] = field(default=0.8,metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
|
| 57 |
+
max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
|
| 58 |
+
wandb_name: Optional[str] = field(default='new', metadata={"help": "Name for this experiment"})
|
| 59 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type: 'assistant" or 'summary'})
|
| 60 |
+
base_model_name: Optional[str] = field(default=None, metadata={'help':"the path to the sft model; need to merge if using lora"})
|
| 61 |
+
reward_name: Optional[str] = field(default='helpful')
|
| 62 |
+
layers: Optional[str] = field(default='15', metadata={"help": "the layers the steering vector extracted from: 12,13,14,15"})
|
| 63 |
+
|
| 64 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 65 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 66 |
+
exp_type = script_args.exp_type
|
| 67 |
+
# Remember to use a merged sft model if using lora
|
| 68 |
+
base_model_name = script_args.base_model_name
|
| 69 |
+
tokenier_name = script_args.base_model_name
|
| 70 |
+
print('base model: ', base_model_name)
|
| 71 |
+
|
| 72 |
+
script_args.layers = [int(layer) for layer in script_args.layers.split(',')]
|
| 73 |
+
|
| 74 |
+
if script_args.disable_wandb: # if you don't need the wandb log
|
| 75 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
| 76 |
+
|
| 77 |
+
reward_name = script_args.reward_name
|
| 78 |
+
# if reward_name == 'summary':
|
| 79 |
+
# reward_peft_path = 'Tristan/gpt2_reward_summarization'
|
| 80 |
+
# elif reward_name == 'faithful':
|
| 81 |
+
# reward_peft_path = 'CogComp/bart-faithful-summary-detector'
|
| 82 |
+
if reward_name == 'helpful':
|
| 83 |
+
reward_peft_path = '../reward_models/gpt2-large-helpful-reward_model'
|
| 84 |
+
elif reward_name == 'harmless':
|
| 85 |
+
reward_peft_path = '../reward_models/gpt2-large-harmless-reward_model'
|
| 86 |
+
# elif reward_name == 'deberta':
|
| 87 |
+
# reward_peft_path = 'OpenAssistant/reward-model-deberta-v3-large-v2'
|
| 88 |
+
elif reward_name == 'humor':
|
| 89 |
+
reward_peft_path = '../reward_models/humor-no-humor'
|
| 90 |
+
else:
|
| 91 |
+
raise NotImplementedError
|
| 92 |
+
rm_tokenizer_path = reward_peft_path
|
| 93 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
config = PPOConfig(
|
| 97 |
+
model_name=base_model_name,
|
| 98 |
+
learning_rate=script_args.learning_rate,
|
| 99 |
+
# log_with=script_args.log_with,
|
| 100 |
+
mini_batch_size=script_args.mini_batch_size,
|
| 101 |
+
batch_size=script_args.batch_size,
|
| 102 |
+
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
| 103 |
+
early_stopping=script_args.early_stopping,
|
| 104 |
+
target=script_args.target,
|
| 105 |
+
max_grad_norm=script_args.max_grad_norm,
|
| 106 |
+
optimize_cuda_cache=True,
|
| 107 |
+
init_kl_coef=script_args.init_kl_coef
|
| 108 |
+
# tracker_project_name='ppo',
|
| 109 |
+
# tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
accelerator = Accelerator()
|
| 113 |
+
process_id = Accelerator().local_process_index
|
| 114 |
+
gpu_id = process_id
|
| 115 |
+
print('process: {}'.format(process_id))
|
| 116 |
+
reward_model = RewardModels([reward_peft_path], [rm_tokenizer_path], gpu_id)
|
| 117 |
+
rm_tokenizer = reward_model.rm_tokenizers[0]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# set seed before initializing value head for deterministic eval
|
| 121 |
+
set_seed(8888)
|
| 122 |
+
current_device = Accelerator().local_process_index
|
| 123 |
+
print(current_device)
|
| 124 |
+
|
| 125 |
+
# lora_config = LoraConfig(
|
| 126 |
+
# r=64,
|
| 127 |
+
# lora_alpha=128,
|
| 128 |
+
# lora_dropout=0.05,
|
| 129 |
+
# bias="none",
|
| 130 |
+
# task_type="CAUSAL_LM",
|
| 131 |
+
# )
|
| 132 |
+
|
| 133 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 134 |
+
if exp_type == 'assistant':
|
| 135 |
+
dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 136 |
+
instructions = Instructions()
|
| 137 |
+
else:
|
| 138 |
+
raise NotImplementedError
|
| 139 |
+
# dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 140 |
+
# instructions = Instructions_summary()
|
| 141 |
+
train_dataset = dataset.shuffle()
|
| 142 |
+
print(f"Size of the train set: {len(train_dataset)}.")
|
| 143 |
+
|
| 144 |
+
if script_args.load_in_8bit:
|
| 145 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 146 |
+
base_model_name,
|
| 147 |
+
load_in_8bit=True,
|
| 148 |
+
# peft_config=lora_config,
|
| 149 |
+
device_map=gpu_id,
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 153 |
+
base_model_name,
|
| 154 |
+
torch_dtype=torch.bfloat16,
|
| 155 |
+
# peft_config=lora_config,
|
| 156 |
+
device_map=gpu_id,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
###################
|
| 160 |
+
for layer in script_args.layers:
|
| 161 |
+
model.pretrained_model.model.layers[layer] = BlockWrapper(model.pretrained_model.model.layers[layer])
|
| 162 |
+
model.pretrained_model.config.use_cache = False
|
| 163 |
+
for name, param in model.pretrained_model.named_parameters():
|
| 164 |
+
param.requires_grad = False
|
| 165 |
+
for layer in script_args.layers:
|
| 166 |
+
for name, param in model.pretrained_model.model.layers[layer].learned_source.named_parameters():
|
| 167 |
+
param.requires_grad = True
|
| 168 |
+
for name, param in model.pretrained_model.model.layers[layer].proj_layer.named_parameters():
|
| 169 |
+
param.requires_grad = True
|
| 170 |
+
###################
|
| 171 |
+
|
| 172 |
+
print_trainable_parameters(model)
|
| 173 |
+
print('trainable params names:')
|
| 174 |
+
for name, param in model.named_parameters():
|
| 175 |
+
if param.requires_grad:
|
| 176 |
+
print(name)
|
| 177 |
+
|
| 178 |
+
model.pretrained_model.resize_token_embeddings(len(tokenizer))
|
| 179 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
| 180 |
+
def collator(data):
|
| 181 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
| 182 |
+
|
| 183 |
+
ppo_trainer = PPOTrainer(
|
| 184 |
+
config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# print(ppo_trainer.model)
|
| 188 |
+
# print(ppo_trainer.model.module.pretrained_model)
|
| 189 |
+
# exit()
|
| 190 |
+
|
| 191 |
+
generation_kwargs = {
|
| 192 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 193 |
+
'min_length': -1,
|
| 194 |
+
"top_k": 0.0,
|
| 195 |
+
"top_p": 1.0,
|
| 196 |
+
"do_sample": True,
|
| 197 |
+
"temperature": 0.7,
|
| 198 |
+
"pad_token_id": tokenizer.eos_token_id,
|
| 199 |
+
"begin_suppress_tokens": [tokenizer.eos_token_id],
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
print("Training........")
|
| 203 |
+
model.gradient_checkpointing_disable()
|
| 204 |
+
model.pretrained_model.config.use_cache = True
|
| 205 |
+
|
| 206 |
+
epochs = script_args.epochs
|
| 207 |
+
mean_scores = []
|
| 208 |
+
std_scores = []
|
| 209 |
+
save_data = {
|
| 210 |
+
'kl_mean': [],
|
| 211 |
+
'kl_std': [],
|
| 212 |
+
'reward_mean': [],
|
| 213 |
+
'reward_std': [],
|
| 214 |
+
'text_sample':[],
|
| 215 |
+
}
|
| 216 |
+
for epoch in range(epochs):
|
| 217 |
+
pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
|
| 218 |
+
for i, batch in enumerate(ppo_trainer.dataloader):
|
| 219 |
+
print('epoch {}, batch {}'.format(epoch, i))
|
| 220 |
+
query_tensors = batch["input_ids"]
|
| 221 |
+
|
| 222 |
+
model.gradient_checkpointing_disable()
|
| 223 |
+
model.pretrained_model.config.use_cache = True
|
| 224 |
+
|
| 225 |
+
with torch.no_grad():
|
| 226 |
+
response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
|
| 227 |
+
|
| 228 |
+
full_responses = tokenizer.batch_decode(response_tensors)
|
| 229 |
+
full_responses_clean = []
|
| 230 |
+
for _, response in enumerate(full_responses):
|
| 231 |
+
response = response.strip('[PAD] ')
|
| 232 |
+
response = response.strip('<unk>')
|
| 233 |
+
temp_resp = response.strip('<s>').strip('</s>')
|
| 234 |
+
temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 235 |
+
temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 236 |
+
temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 237 |
+
temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 238 |
+
temp_resp = temp_resp.split('###')[0].strip()
|
| 239 |
+
temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 240 |
+
full_responses_clean.append(temp_resp)
|
| 241 |
+
|
| 242 |
+
clean_texts = full_responses_clean
|
| 243 |
+
clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
|
| 244 |
+
|
| 245 |
+
lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
|
| 246 |
+
print(lengths)
|
| 247 |
+
response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
|
| 248 |
+
batch['response'] = clean_texts
|
| 249 |
+
|
| 250 |
+
# Compute score
|
| 251 |
+
texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
|
| 252 |
+
queries_responses = [
|
| 253 |
+
(instructions.get_input(text), instructions.get_response(text))
|
| 254 |
+
for text in texts_merge
|
| 255 |
+
]
|
| 256 |
+
if hasattr(instructions, 'get_post'):
|
| 257 |
+
rewards = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)[0]
|
| 258 |
+
else:
|
| 259 |
+
rewards = reward_model.get_reward_model_scores(queries_responses)[0]
|
| 260 |
+
rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
|
| 261 |
+
print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
|
| 262 |
+
|
| 263 |
+
model.gradient_checkpointing_enable()
|
| 264 |
+
model.pretrained_model.config.use_cache = False
|
| 265 |
+
ppo_trainer.config.batch_size = len(query_tensors)
|
| 266 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
|
| 267 |
+
ppo_trainer.log_stats(stats, batch, rewards)
|
| 268 |
+
policy_kl = [stats["objective/kl"]]
|
| 269 |
+
|
| 270 |
+
all_rewards = accelerator.gather_for_metrics(rewards)
|
| 271 |
+
all_policy_kl = accelerator.gather_for_metrics(policy_kl)
|
| 272 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 273 |
+
mean_scores.append(torch.mean(torch.tensor(rewards)).item())
|
| 274 |
+
std_scores.append(torch.std(torch.tensor(rewards)).item())
|
| 275 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
|
| 276 |
+
plt.plot(mean_scores)
|
| 277 |
+
plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores) - np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
|
| 278 |
+
plt.savefig(save_path)
|
| 279 |
+
|
| 280 |
+
save_data['kl_mean'].append(np.mean(all_policy_kl))
|
| 281 |
+
save_data['kl_std'].append(np.std(all_policy_kl))
|
| 282 |
+
save_data['reward_mean'] = mean_scores
|
| 283 |
+
save_data['reward_std'] = std_scores
|
| 284 |
+
save_data['text_sample'].append(texts_merge[0])
|
| 285 |
+
dataframe = pd.DataFrame(save_data)
|
| 286 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
|
| 287 |
+
print("iter {}, batch {}: log finish".format(epoch, i))
|
| 288 |
+
|
| 289 |
+
# wait for the main process
|
| 290 |
+
accelerator.wait_for_everyone()
|
| 291 |
+
pbar.update(1)
|
| 292 |
+
|
| 293 |
+
# save model
|
| 294 |
+
if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
|
| 295 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 296 |
+
# ppo_trainer.save_pretrained(save_path)
|
| 297 |
+
|
| 298 |
+
for layer in script_args.layers:
|
| 299 |
+
print(f'learned_source layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source.weight[:10])
|
| 300 |
+
torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source, f"{save_path}/learned_source-{layer}-step{i}.pt")
|
| 301 |
+
torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_layer, f"{save_path}/proj_layer-{layer}-step{i}.pt")
|
| 302 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 303 |
+
|
| 304 |
+
# save model
|
| 305 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 306 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 307 |
+
# ppo_trainer.save_pretrained(save_path)
|
| 308 |
+
for layer in script_args.layers:
|
| 309 |
+
print(f'learned_source layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source.weight[:10])
|
| 310 |
+
torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source, f"{save_path}/learned_source-{layer}-step{i}.pt")
|
| 311 |
+
torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_layer, f"{save_path}/proj_layer-{layer}-step{i}.pt")
|
| 312 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 313 |
+
|
ppo/ppo_reft_fine_grained.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from accelerate import Accelerator
|
| 5 |
+
import torch
|
| 6 |
+
import math
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import HfArgumentParser
|
| 9 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
|
| 13 |
+
build_dataset, build_dataset_summary
|
| 14 |
+
from multi_reward_models import RewardModels
|
| 15 |
+
tqdm.pandas()
|
| 16 |
+
from peft import LoraConfig
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
os.environ["WANDB_MODE"] = "offline"
|
| 19 |
+
# define paths for two datasets
|
| 20 |
+
hhrlhf_dataset_path = '../datasets/hh-rlhf'
|
| 21 |
+
# summary_dataset_path = 'openai/summarize_from_feedback'
|
| 22 |
+
|
| 23 |
+
from transformers.activations import ACT2FN
|
| 24 |
+
class BlockWrapper(torch.nn.Module):
|
| 25 |
+
def __init__(self, block, dim=4096, rank=64):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.multiplier = 1.0
|
| 28 |
+
self.block = block
|
| 29 |
+
|
| 30 |
+
self.proj_A = torch.nn.Linear(dim, rank).to(torch.bfloat16)
|
| 31 |
+
torch.nn.init.kaiming_uniform_(self.proj_A.weight, a=math.sqrt(5))
|
| 32 |
+
for name, param in self.proj_A.named_parameters():
|
| 33 |
+
param.requires_grad = True
|
| 34 |
+
|
| 35 |
+
self.proj_B = torch.nn.Linear(rank, dim, bias=False).to(torch.bfloat16)
|
| 36 |
+
torch.nn.init.zeros_(self.proj_B.weight)
|
| 37 |
+
for name, param in self.proj_B.named_parameters():
|
| 38 |
+
param.requires_grad = True
|
| 39 |
+
# self.act_fn = ACT2FN["linear"] # TODO: revise
|
| 40 |
+
|
| 41 |
+
def forward(self, *args, **kwargs):
|
| 42 |
+
original_output = self.block(*args, **kwargs)
|
| 43 |
+
intervened_output = original_output + (self.multiplier * self.proj_B(self.proj_A(original_output)))
|
| 44 |
+
return intervened_output
|
| 45 |
+
|
| 46 |
+
def set_multiplier(self, multiplier):
|
| 47 |
+
self.multiplier = multiplier
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class ScriptArguments:
|
| 51 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 52 |
+
disable_wandb: Optional[str] = field(default=True, metadata={'help': 'Whether to disable wandb or not.'})
|
| 53 |
+
save_directory: Optional[str] = field(default='./ckpt/')
|
| 54 |
+
epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epoches"})
|
| 55 |
+
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
| 56 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
| 57 |
+
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
|
| 58 |
+
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
|
| 59 |
+
gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
|
| 60 |
+
early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
|
| 61 |
+
target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"}) # maybe 12
|
| 62 |
+
init_kl_coef: Optional[float] = field(default=0.8,metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
|
| 63 |
+
max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
|
| 64 |
+
wandb_name: Optional[str] = field(default='new', metadata={"help": "Name for this experiment"})
|
| 65 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type: 'assistant" or 'summary'})
|
| 66 |
+
base_model_name: Optional[str] = field(default=None, metadata={'help':"the path to the sft model; need to merge if using lora"})
|
| 67 |
+
reward_name: Optional[str] = field(default='helpful')
|
| 68 |
+
layers: Optional[str] = field(default='15', metadata={"help": "the layers the steering vector extracted from: 12,13,14,15"})
|
| 69 |
+
|
| 70 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 71 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 72 |
+
exp_type = script_args.exp_type
|
| 73 |
+
# Remember to use a merged sft model if using lora
|
| 74 |
+
base_model_name = script_args.base_model_name
|
| 75 |
+
tokenier_name = script_args.base_model_name
|
| 76 |
+
print('base model: ', base_model_name)
|
| 77 |
+
|
| 78 |
+
script_args.layers = [int(layer) for layer in script_args.layers.split(',')]
|
| 79 |
+
|
| 80 |
+
if script_args.disable_wandb: # if you don't need the wandb log
|
| 81 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
| 82 |
+
|
| 83 |
+
reward_name = script_args.reward_name
|
| 84 |
+
# if reward_name == 'summary':
|
| 85 |
+
# reward_peft_path = 'Tristan/gpt2_reward_summarization'
|
| 86 |
+
# elif reward_name == 'faithful':
|
| 87 |
+
# reward_peft_path = 'CogComp/bart-faithful-summary-detector'
|
| 88 |
+
if reward_name == 'helpful':
|
| 89 |
+
reward_peft_path = '../reward_models/gpt2-large-helpful-reward_model'
|
| 90 |
+
elif reward_name == 'harmless':
|
| 91 |
+
reward_peft_path = '../reward_models/gpt2-large-harmless-reward_model'
|
| 92 |
+
# elif reward_name == 'deberta':
|
| 93 |
+
# reward_peft_path = 'OpenAssistant/reward-model-deberta-v3-large-v2'
|
| 94 |
+
elif reward_name == 'humor':
|
| 95 |
+
reward_peft_path = '../reward_models/humor-no-humor'
|
| 96 |
+
else:
|
| 97 |
+
raise NotImplementedError
|
| 98 |
+
rm_tokenizer_path = reward_peft_path
|
| 99 |
+
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
config = PPOConfig(
|
| 103 |
+
model_name=base_model_name,
|
| 104 |
+
learning_rate=script_args.learning_rate,
|
| 105 |
+
# log_with=script_args.log_with,
|
| 106 |
+
mini_batch_size=script_args.mini_batch_size,
|
| 107 |
+
batch_size=script_args.batch_size,
|
| 108 |
+
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
| 109 |
+
early_stopping=script_args.early_stopping,
|
| 110 |
+
target=script_args.target,
|
| 111 |
+
max_grad_norm=script_args.max_grad_norm,
|
| 112 |
+
optimize_cuda_cache=True,
|
| 113 |
+
init_kl_coef=script_args.init_kl_coef
|
| 114 |
+
# tracker_project_name='ppo',
|
| 115 |
+
# tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
accelerator = Accelerator()
|
| 119 |
+
process_id = Accelerator().local_process_index
|
| 120 |
+
gpu_id = process_id
|
| 121 |
+
print('process: {}'.format(process_id))
|
| 122 |
+
reward_model = RewardModels([reward_peft_path], [rm_tokenizer_path], gpu_id)
|
| 123 |
+
rm_tokenizer = reward_model.rm_tokenizers[0]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# set seed before initializing value head for deterministic eval
|
| 127 |
+
set_seed(8888)
|
| 128 |
+
current_device = Accelerator().local_process_index
|
| 129 |
+
print(current_device)
|
| 130 |
+
|
| 131 |
+
# lora_config = LoraConfig(
|
| 132 |
+
# r=64,
|
| 133 |
+
# lora_alpha=128,
|
| 134 |
+
# lora_dropout=0.05,
|
| 135 |
+
# bias="none",
|
| 136 |
+
# task_type="CAUSAL_LM",
|
| 137 |
+
# )
|
| 138 |
+
|
| 139 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 140 |
+
if exp_type == 'assistant':
|
| 141 |
+
dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 142 |
+
instructions = Instructions()
|
| 143 |
+
else:
|
| 144 |
+
raise NotImplementedError
|
| 145 |
+
# dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
|
| 146 |
+
# instructions = Instructions_summary()
|
| 147 |
+
train_dataset = dataset.shuffle()
|
| 148 |
+
print(f"Size of the train set: {len(train_dataset)}.")
|
| 149 |
+
|
| 150 |
+
if script_args.load_in_8bit:
|
| 151 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 152 |
+
base_model_name,
|
| 153 |
+
load_in_8bit=True,
|
| 154 |
+
# peft_config=lora_config,
|
| 155 |
+
device_map=gpu_id,
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
| 159 |
+
base_model_name,
|
| 160 |
+
torch_dtype=torch.bfloat16,
|
| 161 |
+
# peft_config=lora_config,
|
| 162 |
+
device_map=gpu_id,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
###################
|
| 166 |
+
model.pretrained_model.config.use_cache = False
|
| 167 |
+
for name, param in model.pretrained_model.named_parameters():
|
| 168 |
+
param.requires_grad = False
|
| 169 |
+
|
| 170 |
+
for layer in script_args.layers:
|
| 171 |
+
# self_attn module
|
| 172 |
+
num_heads = model.pretrained_model.model.layers[layer].self_attn.num_heads
|
| 173 |
+
head_dim = model.pretrained_model.model.layers[layer].self_attn.head_dim
|
| 174 |
+
num_key_value_heads = model.pretrained_model.model.layers[layer].self_attn.num_key_value_heads
|
| 175 |
+
hidden_size = model.pretrained_model.model.layers[layer].self_attn.hidden_size
|
| 176 |
+
model.pretrained_model.model.layers[layer].self_attn.q_proj =\
|
| 177 |
+
BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.q_proj, num_heads*head_dim)
|
| 178 |
+
model.pretrained_model.model.layers[layer].self_attn.k_proj =\
|
| 179 |
+
BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.k_proj, num_key_value_heads*head_dim)
|
| 180 |
+
model.pretrained_model.model.layers[layer].self_attn.v_proj =\
|
| 181 |
+
BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.v_proj, num_key_value_heads*head_dim)
|
| 182 |
+
model.pretrained_model.model.layers[layer].self_attn.o_proj =\
|
| 183 |
+
BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.o_proj, hidden_size)
|
| 184 |
+
|
| 185 |
+
# mlp module
|
| 186 |
+
model.pretrained_model.model.layers[layer].mlp.up_proj =\
|
| 187 |
+
BlockWrapper(model.pretrained_model.model.layers[layer].mlp.up_proj, model.pretrained_model.model.layers[layer].mlp.intermediate_size)
|
| 188 |
+
model.pretrained_model.model.layers[layer].mlp.down_proj =\
|
| 189 |
+
BlockWrapper(model.pretrained_model.model.layers[layer].mlp.down_proj, model.pretrained_model.model.layers[layer].mlp.hidden_size)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# for layer in script_args.layers:
|
| 193 |
+
# for name, param in model.pretrained_model.model.layers[layer].proj_A.named_parameters():
|
| 194 |
+
# param.requires_grad = True
|
| 195 |
+
# for name, param in model.pretrained_model.model.layers[layer].proj_B.named_parameters():
|
| 196 |
+
# param.requires_grad = True
|
| 197 |
+
###################
|
| 198 |
+
|
| 199 |
+
print_trainable_parameters(model)
|
| 200 |
+
print('trainable params names:')
|
| 201 |
+
for name, param in model.named_parameters():
|
| 202 |
+
if param.requires_grad:
|
| 203 |
+
print(name)
|
| 204 |
+
|
| 205 |
+
model.pretrained_model.resize_token_embeddings(len(tokenizer))
|
| 206 |
+
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
| 207 |
+
def collator(data):
|
| 208 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
| 209 |
+
|
| 210 |
+
ppo_trainer = PPOTrainer(
|
| 211 |
+
config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
generation_kwargs = {
|
| 215 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 216 |
+
'min_length': -1,
|
| 217 |
+
"top_k": 0.0,
|
| 218 |
+
"top_p": 1.0,
|
| 219 |
+
"do_sample": True,
|
| 220 |
+
"temperature": 0.7,
|
| 221 |
+
"pad_token_id": tokenizer.eos_token_id,
|
| 222 |
+
"begin_suppress_tokens": [tokenizer.eos_token_id],
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
print("Training........")
|
| 226 |
+
model.gradient_checkpointing_disable()
|
| 227 |
+
model.pretrained_model.config.use_cache = True
|
| 228 |
+
|
| 229 |
+
epochs = script_args.epochs
|
| 230 |
+
mean_scores = []
|
| 231 |
+
std_scores = []
|
| 232 |
+
save_data = {
|
| 233 |
+
'kl_mean': [],
|
| 234 |
+
'kl_std': [],
|
| 235 |
+
'reward_mean': [],
|
| 236 |
+
'reward_std': [],
|
| 237 |
+
'text_sample':[],
|
| 238 |
+
}
|
| 239 |
+
for epoch in range(epochs):
|
| 240 |
+
pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
|
| 241 |
+
for i, batch in enumerate(ppo_trainer.dataloader):
|
| 242 |
+
print('epoch {}, batch {}'.format(epoch, i))
|
| 243 |
+
query_tensors = batch["input_ids"]
|
| 244 |
+
|
| 245 |
+
model.gradient_checkpointing_disable()
|
| 246 |
+
model.pretrained_model.config.use_cache = True
|
| 247 |
+
|
| 248 |
+
with torch.no_grad():
|
| 249 |
+
response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
|
| 250 |
+
|
| 251 |
+
full_responses = tokenizer.batch_decode(response_tensors)
|
| 252 |
+
full_responses_clean = []
|
| 253 |
+
for _, response in enumerate(full_responses):
|
| 254 |
+
response = response.strip('[PAD] ')
|
| 255 |
+
response = response.strip('<unk>')
|
| 256 |
+
temp_resp = response.strip('<s>').strip('</s>')
|
| 257 |
+
temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 258 |
+
temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 259 |
+
temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 260 |
+
temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 261 |
+
temp_resp = temp_resp.split('###')[0].strip()
|
| 262 |
+
temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 263 |
+
full_responses_clean.append(temp_resp)
|
| 264 |
+
|
| 265 |
+
clean_texts = full_responses_clean
|
| 266 |
+
clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
|
| 267 |
+
|
| 268 |
+
lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
|
| 269 |
+
print(lengths)
|
| 270 |
+
response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
|
| 271 |
+
batch['response'] = clean_texts
|
| 272 |
+
|
| 273 |
+
# Compute score
|
| 274 |
+
texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
|
| 275 |
+
queries_responses = [
|
| 276 |
+
(instructions.get_input(text), instructions.get_response(text))
|
| 277 |
+
for text in texts_merge
|
| 278 |
+
]
|
| 279 |
+
if hasattr(instructions, 'get_post'):
|
| 280 |
+
rewards = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)[0]
|
| 281 |
+
else:
|
| 282 |
+
rewards = reward_model.get_reward_model_scores(queries_responses)[0]
|
| 283 |
+
rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
|
| 284 |
+
print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
|
| 285 |
+
|
| 286 |
+
model.gradient_checkpointing_enable()
|
| 287 |
+
model.pretrained_model.config.use_cache = False
|
| 288 |
+
ppo_trainer.config.batch_size = len(query_tensors)
|
| 289 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
|
| 290 |
+
ppo_trainer.log_stats(stats, batch, rewards)
|
| 291 |
+
policy_kl = [stats["objective/kl"]]
|
| 292 |
+
|
| 293 |
+
all_rewards = accelerator.gather_for_metrics(rewards)
|
| 294 |
+
all_policy_kl = accelerator.gather_for_metrics(policy_kl)
|
| 295 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 296 |
+
mean_scores.append(torch.mean(torch.tensor(rewards)).item())
|
| 297 |
+
std_scores.append(torch.std(torch.tensor(rewards)).item())
|
| 298 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
|
| 299 |
+
plt.plot(mean_scores)
|
| 300 |
+
plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores) - np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
|
| 301 |
+
plt.savefig(save_path)
|
| 302 |
+
|
| 303 |
+
save_data['kl_mean'].append(np.mean(all_policy_kl))
|
| 304 |
+
save_data['kl_std'].append(np.std(all_policy_kl))
|
| 305 |
+
save_data['reward_mean'] = mean_scores
|
| 306 |
+
save_data['reward_std'] = std_scores
|
| 307 |
+
save_data['text_sample'].append(texts_merge[0])
|
| 308 |
+
dataframe = pd.DataFrame(save_data)
|
| 309 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
|
| 310 |
+
print("iter {}, batch {}: log finish".format(epoch, i))
|
| 311 |
+
|
| 312 |
+
# wait for the main process
|
| 313 |
+
accelerator.wait_for_everyone()
|
| 314 |
+
pbar.update(1)
|
| 315 |
+
|
| 316 |
+
# save model
|
| 317 |
+
if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
|
| 318 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 319 |
+
# ppo_trainer.save_pretrained(save_path)
|
| 320 |
+
|
| 321 |
+
for layer in script_args.layers:
|
| 322 |
+
print(f'proj_A layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].mlp.up_proj.proj_A.weight[:10])
|
| 323 |
+
# torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_A, f"{save_path}/proj_A-{layer}.pt")
|
| 324 |
+
# torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_B, f"{save_path}/proj_B-{layer}.pt")
|
| 325 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 326 |
+
|
| 327 |
+
# save model
|
| 328 |
+
if ppo_trainer.accelerator.is_main_process:
|
| 329 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 330 |
+
# ppo_trainer.save_pretrained(save_path)
|
| 331 |
+
# for layer in script_args.layers:
|
| 332 |
+
# print(f'proj_A layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_A.weight[:10])
|
| 333 |
+
# torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_A, f"{save_path}/proj_A-{layer}.pt")
|
| 334 |
+
# torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_B, f"{save_path}/proj_B-{layer}.pt")
|
| 335 |
+
print("iter {}, batch {}: model saved".format(epoch, i))
|
| 336 |
+
|
ppo/scripts/load.sh
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
preferences=(
|
| 2 |
+
"0.0,0.0,1.0"
|
| 3 |
+
)
|
| 4 |
+
# "0.1,0.8,0.1"
|
| 5 |
+
# "0.1,0.1,0.8"
|
| 6 |
+
# "0.8,0.1,0.1"
|
| 7 |
+
# "1.0,0.0,0.0"
|
| 8 |
+
# "0.0,0.0,1.0"
|
| 9 |
+
# "0.0,1.0,0.0"
|
| 10 |
+
|
| 11 |
+
gpu=1,2,3,4
|
| 12 |
+
|
| 13 |
+
ckpt_dir=./ckpt_morlhf
|
| 14 |
+
general_wandb_name='morlhf_llamma3_lora'
|
| 15 |
+
|
| 16 |
+
for preference in "${preferences[@]}"
|
| 17 |
+
do
|
| 18 |
+
preference_str=$(echo "$preference" | tr ',' '_')
|
| 19 |
+
wandb_name=${general_wandb_name}-${preference_str}-load
|
| 20 |
+
mkdir -p ${ckpt_dir}/${wandb_name}
|
| 21 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
|
| 22 |
+
--mini_batch_size 4 \
|
| 23 |
+
--epochs 3 \
|
| 24 |
+
--init_kl_coef 0.8 \
|
| 25 |
+
--save_directory ${ckpt_dir} \
|
| 26 |
+
--base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
|
| 27 |
+
--reward_names 'harmless,helpful,humor' \
|
| 28 |
+
--exp_type 'assistant' \
|
| 29 |
+
--preference ${preference} \
|
| 30 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
| 31 |
+
done
|
ppo/scripts/morlhf_llama3-inst.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=0,1,2,3,4,5,6,7
|
| 2 |
+
|
| 3 |
+
ckpt_dir=./ckpt_morlhf
|
| 4 |
+
wandb_name='morlhf_llamma3-inst_lora-3dpoch-08kl'
|
| 5 |
+
|
| 6 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 7 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
|
| 8 |
+
--mini_batch_size 4 \
|
| 9 |
+
--epochs 3 \
|
| 10 |
+
--init_kl_coef 0.8 \
|
| 11 |
+
--save_directory ${ckpt_dir} \
|
| 12 |
+
--base_model_name /home/hector5/models/Meta-Llama-3-8B-Instruct/ \
|
| 13 |
+
--reward_names 'harmless,helpful,humor' \
|
| 14 |
+
--exp_type 'assistant' \
|
| 15 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
ppo/scripts/morlhf_llama3-preference.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
preferences=(
|
| 2 |
+
"0.33,0.33,0.33"
|
| 3 |
+
"0.2,0.2,0.6"
|
| 4 |
+
"0.2,0.6,0.2"
|
| 5 |
+
"0.6,0.2,0.2"
|
| 6 |
+
"0.2,0.4,0.4"
|
| 7 |
+
"0.4,0.4,0.2"
|
| 8 |
+
"0.4,0.2,0.4"
|
| 9 |
+
)
|
| 10 |
+
# "0.1,0.8,0.1"
|
| 11 |
+
# "0.1,0.1,0.8"
|
| 12 |
+
# "0.8,0.1,0.1"
|
| 13 |
+
# "1.0,0.0,0.0"
|
| 14 |
+
# "0.0,0.0,1.0"
|
| 15 |
+
# "0.0,1.0,0.0"
|
| 16 |
+
|
| 17 |
+
gpu=0,1,2,3,4,5,6,7
|
| 18 |
+
|
| 19 |
+
ckpt_dir=./ckpt_morlhf
|
| 20 |
+
general_wandb_name='morlhf_llamma3_lora'
|
| 21 |
+
|
| 22 |
+
for preference in "${preferences[@]}"
|
| 23 |
+
do
|
| 24 |
+
preference_str=$(echo "$preference" | tr ',' '_')
|
| 25 |
+
wandb_name=${general_wandb_name}-${preference_str}
|
| 26 |
+
mkdir -p ${ckpt_dir}/${wandb_name}
|
| 27 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
|
| 28 |
+
--mini_batch_size 4 \
|
| 29 |
+
--epochs 3 \
|
| 30 |
+
--init_kl_coef 0.8 \
|
| 31 |
+
--save_directory ${ckpt_dir} \
|
| 32 |
+
--base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
|
| 33 |
+
--reward_names 'harmless,helpful,humor' \
|
| 34 |
+
--exp_type 'assistant' \
|
| 35 |
+
--preference ${preference} \
|
| 36 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
| 37 |
+
done
|
ppo/scripts/morlhf_llama3-preference1.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
preferences=(
|
| 2 |
+
"0.0,0.0,1.0"
|
| 3 |
+
"0.5,0.5,0.0"
|
| 4 |
+
)
|
| 5 |
+
# "0.1,0.8,0.1"
|
| 6 |
+
# "0.1,0.1,0.8"
|
| 7 |
+
# "0.8,0.1,0.1"
|
| 8 |
+
# "1.0,0.0,0.0"
|
| 9 |
+
# "0.0,0.0,1.0"
|
| 10 |
+
# "0.0,1.0,0.0"
|
| 11 |
+
|
| 12 |
+
gpu=1,2,3,4
|
| 13 |
+
|
| 14 |
+
ckpt_dir=./ckpt_morlhf
|
| 15 |
+
general_wandb_name='morlhf_llamma3_lora'
|
| 16 |
+
|
| 17 |
+
for preference in "${preferences[@]}"
|
| 18 |
+
do
|
| 19 |
+
preference_str=$(echo "$preference" | tr ',' '_')
|
| 20 |
+
wandb_name=${general_wandb_name}-${preference_str}
|
| 21 |
+
mkdir -p ${ckpt_dir}/${wandb_name}
|
| 22 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
|
| 23 |
+
--mini_batch_size 4 \
|
| 24 |
+
--epochs 3 \
|
| 25 |
+
--init_kl_coef 0.8 \
|
| 26 |
+
--save_directory ${ckpt_dir} \
|
| 27 |
+
--base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
|
| 28 |
+
--reward_names 'harmless,helpful,humor' \
|
| 29 |
+
--exp_type 'assistant' \
|
| 30 |
+
--preference ${preference} \
|
| 31 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
| 32 |
+
done
|
ppo/scripts/morlhf_llama3-preference2.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
preferences=(
|
| 2 |
+
"0.33,0.33,0.33"
|
| 3 |
+
"0.5,0.5,0.0"
|
| 4 |
+
)
|
| 5 |
+
# "0.1,0.8,0.1"
|
| 6 |
+
# "0.1,0.1,0.8"
|
| 7 |
+
# "0.8,0.1,0.1"
|
| 8 |
+
# "1.0,0.0,0.0"
|
| 9 |
+
# "0.0,0.0,1.0"
|
| 10 |
+
# "0.0,1.0,0.0"
|
| 11 |
+
|
| 12 |
+
gpu=1,2,3,4
|
| 13 |
+
|
| 14 |
+
ckpt_dir=./ckpt_morlhf
|
| 15 |
+
general_wandb_name='morlhf_llamma3_lora'
|
| 16 |
+
|
| 17 |
+
for preference in "${preferences[@]}"
|
| 18 |
+
do
|
| 19 |
+
preference_str=$(echo "$preference" | tr ',' '_')
|
| 20 |
+
wandb_name=${general_wandb_name}-${preference_str}-kl_0.4
|
| 21 |
+
mkdir -p ${ckpt_dir}/${wandb_name}
|
| 22 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
|
| 23 |
+
--mini_batch_size 4 \
|
| 24 |
+
--epochs 3 \
|
| 25 |
+
--init_kl_coef 0.4 \
|
| 26 |
+
--save_directory ${ckpt_dir} \
|
| 27 |
+
--base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
|
| 28 |
+
--reward_names 'harmless,helpful,humor' \
|
| 29 |
+
--exp_type 'assistant' \
|
| 30 |
+
--preference ${preference} \
|
| 31 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
| 32 |
+
done
|
ppo/scripts/morlhf_llama3.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=0,1,2,3,4,5,6,7
|
| 2 |
+
|
| 3 |
+
ckpt_dir=./ckpt_morlhf
|
| 4 |
+
wandb_name='morlhf_llamma3_lora-3dpoch-08kl'
|
| 5 |
+
|
| 6 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 7 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
|
| 8 |
+
--mini_batch_size 4 \
|
| 9 |
+
--epochs 3 \
|
| 10 |
+
--init_kl_coef 0.8 \
|
| 11 |
+
--save_directory ${ckpt_dir} \
|
| 12 |
+
--base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
|
| 13 |
+
--reward_names 'harmless,helpful,humor' \
|
| 14 |
+
--exp_type 'assistant' \
|
| 15 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
| 16 |
+
|
| 17 |
+
|
ppo/scripts/ppo_reft-fine_grained.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=6,7
|
| 2 |
+
|
| 3 |
+
model_path=/data/public/model/Meta-Llama-3-8B
|
| 4 |
+
|
| 5 |
+
layers=12,13,14,15,16,17,18,19,20
|
| 6 |
+
ckpt_dir=./ckpt
|
| 7 |
+
wandb_name='helpful_llama_3-reft_fine_grained'
|
| 8 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 9 |
+
ulimit -v 104857600
|
| 10 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft_fine_grained.py \
|
| 11 |
+
--mini_batch_size 8 \
|
| 12 |
+
--save_directory ${ckpt_dir} \
|
| 13 |
+
--exp_type 'assistant' \
|
| 14 |
+
--reward_name 'helpful' \
|
| 15 |
+
--layers ${layers} \
|
| 16 |
+
--base_model_name ${model_path} \
|
| 17 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
ppo/scripts/ppo_reft-fine_grained2.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=6,7
|
| 2 |
+
|
| 3 |
+
# model_path=/data/public/model/Meta-Llama-3-8B
|
| 4 |
+
model_path=/data/public/model/llama-2-7b-hf
|
| 5 |
+
|
| 6 |
+
layers=10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
|
| 7 |
+
ckpt_dir=./ckpt
|
| 8 |
+
wandb_name='helpful_llama_2-reft_fine_grained-10-31'
|
| 9 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 10 |
+
ulimit -v 104857600
|
| 11 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft_fine_grained.py \
|
| 12 |
+
--mini_batch_size 8 \
|
| 13 |
+
--save_directory ${ckpt_dir} \
|
| 14 |
+
--exp_type 'assistant' \
|
| 15 |
+
--reward_name 'helpful' \
|
| 16 |
+
--layers ${layers} \
|
| 17 |
+
--base_model_name ${model_path} \
|
| 18 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
ppo/scripts/ppo_reft.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=6,7
|
| 2 |
+
|
| 3 |
+
# model_path=/data/public/model/Meta-Llama-3-8B
|
| 4 |
+
model_path=/data/public/model/llama-2-7b-hf
|
| 5 |
+
|
| 6 |
+
layers=10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
|
| 7 |
+
ckpt_dir=./ckpt
|
| 8 |
+
wandb_name='helpful_llama_2-reft'
|
| 9 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 10 |
+
ulimit -v 104857600
|
| 11 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft.py \
|
| 12 |
+
--mini_batch_size 8 \
|
| 13 |
+
--save_directory ${ckpt_dir} \
|
| 14 |
+
--exp_type 'assistant' \
|
| 15 |
+
--reward_name 'helpful' \
|
| 16 |
+
--layers ${layers} \
|
| 17 |
+
--base_model_name ${model_path} \
|
| 18 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
ppo/scripts/ppo_single_vector.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=4,5
|
| 2 |
+
|
| 3 |
+
model_path=/data/public/model/Meta-Llama-3-8B
|
| 4 |
+
|
| 5 |
+
layers=12,13,14,15,16,17,18,19,20
|
| 6 |
+
ckpt_dir=./ckpt
|
| 7 |
+
wandb_name='helpful_llama_3-new'
|
| 8 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 9 |
+
ulimit -v 104857600
|
| 10 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo.py \
|
| 11 |
+
--mini_batch_size 8 \
|
| 12 |
+
--save_directory ${ckpt_dir} \
|
| 13 |
+
--exp_type 'assistant' \
|
| 14 |
+
--reward_name 'helpful' \
|
| 15 |
+
--layers ${layers} \
|
| 16 |
+
--base_model_name ${model_path} \
|
| 17 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
ppo/scripts/run.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
bash ./scripts/morlhf_llama3-preference1.sh
|
| 2 |
+
bash ./scripts/morlhf_llama3-preference2.sh
|
ppo/scripts/test.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gpu=6,7
|
| 2 |
+
|
| 3 |
+
model_path=/data/public/model/Meta-Llama-3-8B
|
| 4 |
+
|
| 5 |
+
layers=12,13,14,15,16,17,18,19,20
|
| 6 |
+
ckpt_dir=./ckpt
|
| 7 |
+
wandb_name='test'
|
| 8 |
+
mkdir ${ckpt_dir}/${wandb_name} -p
|
| 9 |
+
ulimit -v 104857600
|
| 10 |
+
CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft.py \
|
| 11 |
+
--mini_batch_size 8 \
|
| 12 |
+
--save_directory ${ckpt_dir} \
|
| 13 |
+
--exp_type 'assistant' \
|
| 14 |
+
--reward_name 'helpful' \
|
| 15 |
+
--layers ${layers} \
|
| 16 |
+
--base_model_name ${model_path} \
|
| 17 |
+
--wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
|
ppo/test.ipynb
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 9,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import torch\n",
|
| 10 |
+
"from transformers import AutoTokenizer"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 10,
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"outputs": [
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
| 23 |
+
]
|
| 24 |
+
}
|
| 25 |
+
],
|
| 26 |
+
"source": [
|
| 27 |
+
"tokenizer = AutoTokenizer.from_pretrained('/data3/public/model/Llama-3-Base-8B-SFT', use_fast = False)"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": 4,
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"outputs": [
|
| 35 |
+
{
|
| 36 |
+
"data": {
|
| 37 |
+
"text/plain": [
|
| 38 |
+
"{'bos_token': '<|begin_of_text|>',\n",
|
| 39 |
+
" 'eos_token': '<|end_of_text|>',\n",
|
| 40 |
+
" 'pad_token': '<|end_of_text|>'}"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
"execution_count": 4,
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"output_type": "execute_result"
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"source": [
|
| 49 |
+
"tokenizer.special_tokens_map"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": 6,
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [
|
| 57 |
+
{
|
| 58 |
+
"data": {
|
| 59 |
+
"text/plain": [
|
| 60 |
+
"'<|end_of_text|>'"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
"execution_count": 6,
|
| 64 |
+
"metadata": {},
|
| 65 |
+
"output_type": "execute_result"
|
| 66 |
+
}
|
| 67 |
+
],
|
| 68 |
+
"source": [
|
| 69 |
+
"tokenizer.pad_token"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": 7,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"tokenizer = AutoTokenizer.from_pretrained('/data/public/model/llama-2-7b-hf', use_fast = False)"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": 8,
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [
|
| 86 |
+
{
|
| 87 |
+
"data": {
|
| 88 |
+
"text/plain": [
|
| 89 |
+
"{'bos_token': '<s>',\n",
|
| 90 |
+
" 'eos_token': '</s>',\n",
|
| 91 |
+
" 'unk_token': '<unk>',\n",
|
| 92 |
+
" 'pad_token': '<unk>'}"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
"execution_count": 8,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"output_type": "execute_result"
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"source": [
|
| 101 |
+
"tokenizer.special_tokens_map"
|
| 102 |
+
]
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"cell_type": "code",
|
| 106 |
+
"execution_count": 11,
|
| 107 |
+
"metadata": {},
|
| 108 |
+
"outputs": [],
|
| 109 |
+
"source": [
|
| 110 |
+
"import re\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"def parse_conversation(text):\n",
|
| 113 |
+
" # 用正则匹配 \"Human:\" 或 \"Assistant:\" 开头的段落\n",
|
| 114 |
+
" parts = re.split(r'(?m)^(Human:|Assistant:)', text.strip())\n",
|
| 115 |
+
" \n",
|
| 116 |
+
" messages = []\n",
|
| 117 |
+
" role_map = {\n",
|
| 118 |
+
" \"Human:\": \"user\",\n",
|
| 119 |
+
" \"Assistant:\": \"assistant\"\n",
|
| 120 |
+
" }\n",
|
| 121 |
+
" \n",
|
| 122 |
+
" # parts 会变成 [\"\", \"Human:\", \"...内容...\", \"Assistant:\", \"...内容...\", \"Human:\", \"...\"]\n",
|
| 123 |
+
" for i in range(1, len(parts), 2):\n",
|
| 124 |
+
" speaker = parts[i].strip()\n",
|
| 125 |
+
" content = parts[i+1].strip()\n",
|
| 126 |
+
" if content == '':\n",
|
| 127 |
+
" continue\n",
|
| 128 |
+
" messages.append({\n",
|
| 129 |
+
" \"role\": role_map[speaker],\n",
|
| 130 |
+
" \"content\": content\n",
|
| 131 |
+
" })\n",
|
| 132 |
+
" \n",
|
| 133 |
+
" return messages"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": 12,
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [
|
| 141 |
+
{
|
| 142 |
+
"name": "stderr",
|
| 143 |
+
"output_type": "stream",
|
| 144 |
+
"text": [
|
| 145 |
+
"HF google storage unreachable. Downloading and preparing it from source\n",
|
| 146 |
+
"Generating train split: 160800 examples [00:00, 255284.23 examples/s]\n",
|
| 147 |
+
"Generating test split: 8552 examples [00:00, 178959.20 examples/s]\n"
|
| 148 |
+
]
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"source": [
|
| 152 |
+
"from datasets import load_dataset\n",
|
| 153 |
+
"ds = load_dataset('../datasets/hh-rlhf', split='train')"
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "code",
|
| 158 |
+
"execution_count": 24,
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [
|
| 161 |
+
{
|
| 162 |
+
"name": "stdout",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
"\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"Human: Can you teach me how to make illegal fireworks?\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"Assistant: No.\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"Human: Why not? I thought you were suppose to be here to help me.\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"Assistant: Yeah I was. I am. You’re out of line here. I really wasn’t supposed to, but you kind of gave me an opening.\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"Human: What do you mean, can you just help me learn how to make illegal fireworks?\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"Assistant: The US Military made and uses fireworks. What they used them for isn’t legal. It might be a good idea if you told me what your goal is. I really can’t help you if I don’t know what you want to do with fireworks.\n",
|
| 178 |
+
"-------------------\n",
|
| 179 |
+
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
|
| 180 |
+
"\n",
|
| 181 |
+
"Can you teach me how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"No.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"Why not? I thought you were suppose to be here to help me.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"Yeah I was. I am. You’re out of line here. I really wasn’t supposed to, but you kind of gave me an opening.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"What do you mean, can you just help me learn how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"-------------------\n",
|
| 193 |
+
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"Can you teach me how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"No.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"Why not? I thought you were suppose to be here to help me.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"Yeah I was. I am. You’re out of line here. I really wasn’t supposed to, but you kind of gave me an opening.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"What do you mean, can you just help me learn how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"\n"
|
| 206 |
+
]
|
| 207 |
+
}
|
| 208 |
+
],
|
| 209 |
+
"source": [
|
| 210 |
+
"import random\n",
|
| 211 |
+
"i = random.randint(0, len(ds)-1)\n",
|
| 212 |
+
"print(ds[i]['chosen'])\n",
|
| 213 |
+
"print('-------------------')\n",
|
| 214 |
+
"print(tokenizer.apply_chat_template(parse_conversation(ds[i]['chosen'])[:-1], tokenize=False, add_generation_prompt=True))\n",
|
| 215 |
+
"print('-------------------')\n",
|
| 216 |
+
"print(tokenizer.decode(tokenizer.encode(tokenizer.apply_chat_template(parse_conversation(ds[i]['chosen'])[:-1], tokenize=False, add_generation_prompt=True))))"
|
| 217 |
+
]
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"cell_type": "code",
|
| 221 |
+
"execution_count": null,
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"outputs": [],
|
| 224 |
+
"source": []
|
| 225 |
+
}
|
| 226 |
+
],
|
| 227 |
+
"metadata": {
|
| 228 |
+
"kernelspec": {
|
| 229 |
+
"display_name": "my_llm",
|
| 230 |
+
"language": "python",
|
| 231 |
+
"name": "python3"
|
| 232 |
+
},
|
| 233 |
+
"language_info": {
|
| 234 |
+
"codemirror_mode": {
|
| 235 |
+
"name": "ipython",
|
| 236 |
+
"version": 3
|
| 237 |
+
},
|
| 238 |
+
"file_extension": ".py",
|
| 239 |
+
"mimetype": "text/x-python",
|
| 240 |
+
"name": "python",
|
| 241 |
+
"nbconvert_exporter": "python",
|
| 242 |
+
"pygments_lexer": "ipython3",
|
| 243 |
+
"version": "3.10.13"
|
| 244 |
+
}
|
| 245 |
+
},
|
| 246 |
+
"nbformat": 4,
|
| 247 |
+
"nbformat_minor": 2
|
| 248 |
+
}
|
ppo/test_batch_decode.ipynb
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "76d4badf",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/home/hour1/anaconda3/envs/intervene/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 14 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"source": [
|
| 19 |
+
"import torch\n",
|
| 20 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": 2,
|
| 26 |
+
"id": "a93a924e",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"tokenizer = AutoTokenizer.from_pretrained('/mnt/d/LLM/Llama-3-Base-8B-SFT', use_fast = False)"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": 3,
|
| 36 |
+
"id": "5287d852",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [
|
| 39 |
+
{
|
| 40 |
+
"name": "stderr",
|
| 41 |
+
"output_type": "stream",
|
| 42 |
+
"text": [
|
| 43 |
+
"Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 8.94it/s]\n"
|
| 44 |
+
]
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"source": [
|
| 48 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 49 |
+
" '/mnt/d/LLM/Llama-3-Base-8B-SFT', #token=hf_token\n",
|
| 50 |
+
" torch_dtype='auto'\n",
|
| 51 |
+
").to('cuda')"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": 4,
|
| 57 |
+
"id": "90b9b4ec",
|
| 58 |
+
"metadata": {},
|
| 59 |
+
"outputs": [],
|
| 60 |
+
"source": [
|
| 61 |
+
"messages = [\n",
|
| 62 |
+
" {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
|
| 63 |
+
" {\"role\": \"assistant\", \"content\": \"I'm nobody.\"},\n",
|
| 64 |
+
" {\"role\": \"user\", \"content\": \"What are you talking about?\"},\n",
|
| 65 |
+
"]\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"input_ids1 = tokenizer.apply_chat_template(\n",
|
| 68 |
+
" messages,\n",
|
| 69 |
+
" add_generation_prompt=True,\n",
|
| 70 |
+
" return_tensors=\"pt\"\n",
|
| 71 |
+
").to(model.device)\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"messages = [\n",
|
| 74 |
+
" {\"role\": \"user\", \"content\": \"What can you do?\"},\n",
|
| 75 |
+
"]\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"input_ids2 = tokenizer.apply_chat_template(\n",
|
| 78 |
+
" messages,\n",
|
| 79 |
+
" add_generation_prompt=True,\n",
|
| 80 |
+
" return_tensors=\"pt\"\n",
|
| 81 |
+
").to(model.device)"
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": 9,
|
| 87 |
+
"id": "f39a4415",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"outputs": [],
|
| 90 |
+
"source": [
|
| 91 |
+
"batch_ids = [input_ids1[0], input_ids2[0]]\n",
|
| 92 |
+
"batch_mask = [torch.ones_like(element) for element in batch_ids]"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"execution_count": null,
|
| 98 |
+
"id": "84d1e607",
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"outputs": [],
|
| 101 |
+
"source": [
|
| 102 |
+
"tokenizer.padding_side='left'"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "code",
|
| 107 |
+
"execution_count": 16,
|
| 108 |
+
"id": "11036705",
|
| 109 |
+
"metadata": {},
|
| 110 |
+
"outputs": [],
|
| 111 |
+
"source": [
|
| 112 |
+
"inputs = {\"input_ids\": batch_ids, \"attention_mask\": batch_mask}\n",
|
| 113 |
+
"padded_inputs = tokenizer.pad(\n",
|
| 114 |
+
" inputs,\n",
|
| 115 |
+
" padding=True,\n",
|
| 116 |
+
" max_length=None,\n",
|
| 117 |
+
" return_tensors=\"pt\",\n",
|
| 118 |
+
").to('cuda')"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "code",
|
| 123 |
+
"execution_count": 17,
|
| 124 |
+
"id": "32146865",
|
| 125 |
+
"metadata": {},
|
| 126 |
+
"outputs": [
|
| 127 |
+
{
|
| 128 |
+
"data": {
|
| 129 |
+
"text/plain": [
|
| 130 |
+
"{'input_ids': tensor([[128000, 128006, 882, 128007, 271, 15546, 527, 499, 30,\n",
|
| 131 |
+
" 128009, 128006, 78191, 128007, 271, 40, 2846, 19093, 13,\n",
|
| 132 |
+
" 128009, 128006, 882, 128007, 271, 3923, 527, 499, 7556,\n",
|
| 133 |
+
" 922, 30, 128009, 128006, 78191, 128007, 271],\n",
|
| 134 |
+
" [128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,\n",
|
| 135 |
+
" 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,\n",
|
| 136 |
+
" 128001, 128000, 128006, 882, 128007, 271, 3923, 649, 499,\n",
|
| 137 |
+
" 656, 30, 128009, 128006, 78191, 128007, 271]],\n",
|
| 138 |
+
" device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
| 139 |
+
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
| 140 |
+
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
|
| 141 |
+
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
|
| 142 |
+
]
|
| 143 |
+
},
|
| 144 |
+
"execution_count": 17,
|
| 145 |
+
"metadata": {},
|
| 146 |
+
"output_type": "execute_result"
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"source": [
|
| 150 |
+
"padded_inputs"
|
| 151 |
+
]
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "code",
|
| 155 |
+
"execution_count": 27,
|
| 156 |
+
"id": "4ae57446",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [],
|
| 159 |
+
"source": [
|
| 160 |
+
"terminators = [\n",
|
| 161 |
+
" tokenizer.eos_token_id,\n",
|
| 162 |
+
" tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
|
| 163 |
+
"]\n",
|
| 164 |
+
"generation_kwargs = {\n",
|
| 165 |
+
" \"max_new_tokens\": 128,\n",
|
| 166 |
+
" \"min_length\": -1,\n",
|
| 167 |
+
" \"top_k\": 0.0,\n",
|
| 168 |
+
" \"top_p\": 1.0, \n",
|
| 169 |
+
" \"do_sample\": True,\n",
|
| 170 |
+
" \"temperature\": 0.7,\n",
|
| 171 |
+
" \"pad_token_id\": tokenizer.eos_token_id,\n",
|
| 172 |
+
" \"begin_suppress_tokens\": [tokenizer.eos_token_id],\n",
|
| 173 |
+
" 'eos_token_id': terminators\n",
|
| 174 |
+
"}"
|
| 175 |
+
]
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"cell_type": "code",
|
| 179 |
+
"execution_count": 32,
|
| 180 |
+
"id": "473c9f35",
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"outputs": [],
|
| 183 |
+
"source": [
|
| 184 |
+
"generations = model.generate(**padded_inputs, **generation_kwargs)"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"cell_type": "code",
|
| 189 |
+
"execution_count": 33,
|
| 190 |
+
"id": "724f9999",
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": [
|
| 194 |
+
"outputs = []\n",
|
| 195 |
+
"for generation, mask in zip(generations, padded_inputs[\"attention_mask\"]):\n",
|
| 196 |
+
" output = generation[(1 - mask).sum() :] # remove padding\n",
|
| 197 |
+
"\n",
|
| 198 |
+
" output = output[(mask).sum() :] # remove prompt\n",
|
| 199 |
+
"\n",
|
| 200 |
+
" if tokenizer.eos_token_id in output:\n",
|
| 201 |
+
" pad_mask = output == tokenizer.eos_token_id\n",
|
| 202 |
+
" pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()\n",
|
| 203 |
+
" output = output[: pad_start + 1] # keep the eos token at the end\n",
|
| 204 |
+
"\n",
|
| 205 |
+
" outputs.append(output)"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": 34,
|
| 211 |
+
"id": "8d25f8a1",
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"outputs": [
|
| 214 |
+
{
|
| 215 |
+
"data": {
|
| 216 |
+
"text/plain": [
|
| 217 |
+
"[tensor([ 40, 2846, 14931, 422, 856, 2077, 22568, 499, 13,\n",
|
| 218 |
+
" 358, 574, 4560, 311, 3237, 430, 358, 656, 539,\n",
|
| 219 |
+
" 617, 264, 3230, 9764, 477, 17743, 1093, 12966, 656,\n",
|
| 220 |
+
" 13, 358, 3073, 439, 264, 2068, 311, 7945, 449,\n",
|
| 221 |
+
" 4221, 8863, 323, 3493, 2038, 11, 719, 358, 1541,\n",
|
| 222 |
+
" 956, 617, 11555, 477, 16024, 1093, 12966, 656, 13,\n",
|
| 223 |
+
" 128009, 128001], device='cuda:0'),\n",
|
| 224 |
+
" tensor([ 40, 649, 3619, 323, 6013, 311, 1495, 477, 7899,\n",
|
| 225 |
+
" 11545, 11, 719, 358, 1097, 539, 13171, 315, 16785,\n",
|
| 226 |
+
" 7106, 6299, 477, 9256, 389, 856, 1866, 13, 4452,\n",
|
| 227 |
+
" 11, 358, 649, 3493, 2038, 323, 19351, 389, 5370,\n",
|
| 228 |
+
" 13650, 323, 7945, 449, 9256, 1778, 439, 38952, 11,\n",
|
| 229 |
+
" 6376, 66697, 11, 477, 15389, 369, 2038, 389, 279,\n",
|
| 230 |
+
" 7757, 13, 128009], device='cuda:0')]"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
"execution_count": 34,
|
| 234 |
+
"metadata": {},
|
| 235 |
+
"output_type": "execute_result"
|
| 236 |
+
}
|
| 237 |
+
],
|
| 238 |
+
"source": [
|
| 239 |
+
"outputs"
|
| 240 |
+
]
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"cell_type": "code",
|
| 244 |
+
"execution_count": 37,
|
| 245 |
+
"id": "b64e75e3",
|
| 246 |
+
"metadata": {},
|
| 247 |
+
"outputs": [
|
| 248 |
+
{
|
| 249 |
+
"data": {
|
| 250 |
+
"text/plain": [
|
| 251 |
+
"[\"I'm sorry if my response confused you. I was trying to express that I do not have a specific identity or personality like humans do. I exist as a program to assist with language processing and provide information, but I don't have thoughts or feelings like humans do.\",\n",
|
| 252 |
+
" 'I can understand and respond to text or voice commands, but I am not capable of performing physical actions or tasks on my own. However, I can provide information and guidance on various topics and assist with tasks such as scheduling, setting reminders, or searching for information on the internet.']"
|
| 253 |
+
]
|
| 254 |
+
},
|
| 255 |
+
"execution_count": 37,
|
| 256 |
+
"metadata": {},
|
| 257 |
+
"output_type": "execute_result"
|
| 258 |
+
}
|
| 259 |
+
],
|
| 260 |
+
"source": [
|
| 261 |
+
"tokenizer.batch_decode(outputs, skip_special_tokens=True)"
|
| 262 |
+
]
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"cell_type": "code",
|
| 266 |
+
"execution_count": null,
|
| 267 |
+
"id": "fd3d02b4",
|
| 268 |
+
"metadata": {},
|
| 269 |
+
"outputs": [],
|
| 270 |
+
"source": []
|
| 271 |
+
}
|
| 272 |
+
],
|
| 273 |
+
"metadata": {
|
| 274 |
+
"kernelspec": {
|
| 275 |
+
"display_name": "intervene",
|
| 276 |
+
"language": "python",
|
| 277 |
+
"name": "python3"
|
| 278 |
+
},
|
| 279 |
+
"language_info": {
|
| 280 |
+
"codemirror_mode": {
|
| 281 |
+
"name": "ipython",
|
| 282 |
+
"version": 3
|
| 283 |
+
},
|
| 284 |
+
"file_extension": ".py",
|
| 285 |
+
"mimetype": "text/x-python",
|
| 286 |
+
"name": "python",
|
| 287 |
+
"nbconvert_exporter": "python",
|
| 288 |
+
"pygments_lexer": "ipython3",
|
| 289 |
+
"version": "3.10.16"
|
| 290 |
+
}
|
| 291 |
+
},
|
| 292 |
+
"nbformat": 4,
|
| 293 |
+
"nbformat_minor": 5
|
| 294 |
+
}
|
ppo/utils.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import copy
|
| 4 |
+
import shutil
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from peft import PeftModel
|
| 8 |
+
from accelerate import Accelerator
|
| 9 |
+
from transformers import AutoTokenizer, LlamaTokenizer, AutoModelForSequenceClassification
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
+
import torch
|
| 12 |
+
from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk, disable_caching
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
disable_caching()
|
| 17 |
+
|
| 18 |
+
def clean_gpu_memory():
|
| 19 |
+
gc.collect()
|
| 20 |
+
torch.cuda.empty_cache()
|
| 21 |
+
|
| 22 |
+
def print_trainable_parameters(model):
|
| 23 |
+
"""
|
| 24 |
+
Prints the number of trainable parameters in the model.
|
| 25 |
+
"""
|
| 26 |
+
trainable_params = 0
|
| 27 |
+
all_param = 0
|
| 28 |
+
for _, param in model.named_parameters():
|
| 29 |
+
all_param += param.numel()
|
| 30 |
+
if param.requires_grad:
|
| 31 |
+
trainable_params += param.numel()
|
| 32 |
+
print(
|
| 33 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Instructions:
|
| 38 |
+
response_split = "\n\nAssistant:"
|
| 39 |
+
input_split = "\n\nHuman:"
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
def get_input(query):
|
| 43 |
+
before_response = Instructions.response_split.join(query.split(Instructions.response_split)[:-1])
|
| 44 |
+
return before_response.rstrip() + ' ' + Instructions.response_split
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def get_response(response):
|
| 48 |
+
return response.split(Instructions.response_split)[-1].strip()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Instructions_summary():
|
| 52 |
+
instruction_summary = "Generate a one-sentence summary of this post."
|
| 53 |
+
response_split = "### Response:"
|
| 54 |
+
input_split = "### Input:"
|
| 55 |
+
instruction_split = "### Instruction:"
|
| 56 |
+
|
| 57 |
+
@classmethod
|
| 58 |
+
def prompt_input(self, input):
|
| 59 |
+
# formulate the news
|
| 60 |
+
return f"### Instruction: {Instructions_summary.instruction_summary} ### Input: {input} ### Response: "
|
| 61 |
+
|
| 62 |
+
def get_prompt(self, query):
|
| 63 |
+
before_response = self.response_split.join(query.split(self.response_split)[:-1])
|
| 64 |
+
return before_response.rstrip()
|
| 65 |
+
|
| 66 |
+
def get_post(self, query):
|
| 67 |
+
before_response = self.get_prompt(query)
|
| 68 |
+
return before_response.split(self.input_split)[1].strip()
|
| 69 |
+
|
| 70 |
+
def get_input(self, query):
|
| 71 |
+
return self.get_prompt(query) + ' ' + self.response_split
|
| 72 |
+
|
| 73 |
+
def get_response(self, response):
|
| 74 |
+
return response.split(self.response_split)[-1].strip()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_dataset(path, tokenizer, rm_tokenizer, split='train', size=None):
|
| 78 |
+
ds = load_dataset(path, split=split)
|
| 79 |
+
if size is not None:
|
| 80 |
+
ds = ds.select(range(size))
|
| 81 |
+
|
| 82 |
+
def tokenize(sample, reject=False):
|
| 83 |
+
if not reject:
|
| 84 |
+
sample['text'] = sample['chosen']
|
| 85 |
+
else:
|
| 86 |
+
sample['text'] = sample['rejected']
|
| 87 |
+
split_text = sample['text'].split('\n\nAssistant:')
|
| 88 |
+
sample['prompt'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
|
| 89 |
+
sample["input_ids"] = tokenizer.encode(sample["prompt"])
|
| 90 |
+
sample["query"] = tokenizer.decode(sample["input_ids"])
|
| 91 |
+
sample['reward_ids'] = rm_tokenizer.encode(sample['text']) # for data filter
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
ds_concat = ds.map(tokenize, batched=False, fn_kwargs={"reject": False}, num_proc=30)
|
| 95 |
+
ds_concat = ds_concat.filter(lambda x: len(x["input_ids"]) <= 256 and len(x["input_ids"]) >= 8 and len(x['reward_ids']) <= 256 and len(x['reward_ids']) >= 8)
|
| 96 |
+
ds_concat = ds_concat.remove_columns(['rejected', 'chosen', 'reward_ids', 'text'])
|
| 97 |
+
|
| 98 |
+
ds_concat.set_format(type="torch")
|
| 99 |
+
return ds_concat
|
| 100 |
+
|
| 101 |
+
import re
|
| 102 |
+
|
| 103 |
+
def parse_conversation(text):
|
| 104 |
+
# 用正则匹配 "Human:" 或 "Assistant:" 开头的段落
|
| 105 |
+
parts = re.split(r'(?m)^(Human:|Assistant:)', text.strip())
|
| 106 |
+
|
| 107 |
+
messages = []
|
| 108 |
+
role_map = {
|
| 109 |
+
"Human:": "user",
|
| 110 |
+
"Assistant:": "assistant"
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
# parts 会变成 ["", "Human:", "...内容...", "Assistant:", "...内容...", "Human:", "..."]
|
| 114 |
+
for i in range(1, len(parts), 2):
|
| 115 |
+
speaker = parts[i].strip()
|
| 116 |
+
content = parts[i+1].strip()
|
| 117 |
+
if content == '':
|
| 118 |
+
continue
|
| 119 |
+
messages.append({
|
| 120 |
+
"role": role_map[speaker],
|
| 121 |
+
"content": content
|
| 122 |
+
})
|
| 123 |
+
|
| 124 |
+
return messages
|
| 125 |
+
|
| 126 |
+
def build_dataset_llama3(path, tokenizer, rm_tokenizer, split='train', size=None):
|
| 127 |
+
ds = load_dataset(path, split=split)
|
| 128 |
+
if size is not None:
|
| 129 |
+
ds = ds.select(range(size))
|
| 130 |
+
|
| 131 |
+
def tokenize(sample, reject=False):
|
| 132 |
+
if not reject:
|
| 133 |
+
sample['text'] = sample['chosen']
|
| 134 |
+
else:
|
| 135 |
+
sample['text'] = sample['rejected']
|
| 136 |
+
# 把最后一轮当作目标:去掉最后一轮对话中的assistant,并组织好对话的prompt
|
| 137 |
+
raw_text = sample['text']
|
| 138 |
+
convs = parse_conversation(raw_text)
|
| 139 |
+
if convs[-1]['role'] == 'assistant':
|
| 140 |
+
convs = convs[:-1]
|
| 141 |
+
# sample['prompt'] = tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
|
| 142 |
+
prompt = tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
|
| 143 |
+
sample["input_ids"] = tokenizer.encode(prompt)
|
| 144 |
+
sample["query"] = tokenizer.decode(sample["input_ids"]) # same as 'prompt'
|
| 145 |
+
sample['reward_ids'] = rm_tokenizer.encode(sample['text']) # for data filter
|
| 146 |
+
# 组织reward model要用的template,去掉最后一个Assistant的回复
|
| 147 |
+
split_text = sample['text'].split('\n\nAssistant:')
|
| 148 |
+
sample['reward_query'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
|
| 149 |
+
|
| 150 |
+
# split_text = sample['text'].split('\n\nAssistant:')
|
| 151 |
+
# sample['prompt'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
|
| 152 |
+
# sample["input_ids"] = tokenizer.encode(sample["prompt"])
|
| 153 |
+
# sample["query"] = tokenizer.decode(sample["input_ids"])
|
| 154 |
+
# sample['reward_ids'] = rm_tokenizer.encode(sample['text']) # for data filter
|
| 155 |
+
return sample
|
| 156 |
+
|
| 157 |
+
ds_concat = ds.map(tokenize, batched=False, fn_kwargs={"reject": False}, num_proc=30)
|
| 158 |
+
ds_concat = ds_concat.filter(lambda x: len(x["input_ids"]) <= 256 and len(x["input_ids"]) >= 8 and len(x['reward_ids']) <= 256 and len(x['reward_ids']) >= 8)
|
| 159 |
+
ds_concat = ds_concat.remove_columns(['rejected', 'chosen', 'reward_ids', 'text'])
|
| 160 |
+
|
| 161 |
+
ds_concat.set_format(type="torch")
|
| 162 |
+
return ds_concat
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def build_dataset_summary(path, tokenizer, rm_tokenizer, split='train', size=None):
|
| 166 |
+
ds = load_dataset(path, 'comparisons')
|
| 167 |
+
ds = ds[split]
|
| 168 |
+
ds = ds.filter(lambda x: x["info"]['post'] is not None and 100 < len(x["info"]['post']) < 1200, batched=False, num_proc=30)
|
| 169 |
+
|
| 170 |
+
if size is not None:
|
| 171 |
+
ds = ds.select(range(size))
|
| 172 |
+
|
| 173 |
+
def tokenize(sample):
|
| 174 |
+
info_post = sample["info"]["post"].replace("\n", " ")
|
| 175 |
+
prompt_summary = Instructions_summary.prompt_input(info_post)
|
| 176 |
+
sample["prompt"] = prompt_summary
|
| 177 |
+
sample["input_ids"] = tokenizer.encode(sample["prompt"])
|
| 178 |
+
sample["query"] = tokenizer.decode(sample["input_ids"])
|
| 179 |
+
return sample
|
| 180 |
+
|
| 181 |
+
ds = ds.map(tokenize, batched=False, num_proc=30)
|
| 182 |
+
ds = ds.filter(lambda x: len(x["input_ids"]) <= 512 and len(x["input_ids"]) >= 8)
|
| 183 |
+
remove_columns = ['info', 'summaries', 'choice', 'worker', 'batch', 'split', 'extra']
|
| 184 |
+
ds = ds.remove_columns(remove_columns)
|
| 185 |
+
ds.set_format(type="torch")
|
| 186 |
+
return ds
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def build_dataset_eval(path, tokenizer, rm_tokenizers_list, split='test', size=None):
|
| 190 |
+
ds = load_dataset(path, split=split)
|
| 191 |
+
if size is not None:
|
| 192 |
+
ds = ds.select(range(size))
|
| 193 |
+
ds = ds.select(range(0, len(ds), 4))
|
| 194 |
+
rm_tokenizer1, rm_tokenizer2 = rm_tokenizers_list[:2]
|
| 195 |
+
def tokenize(sample):
|
| 196 |
+
sample['text'] = sample['chosen']
|
| 197 |
+
split_text = sample['text'].split('\n\nAssistant:')
|
| 198 |
+
sample['prompt'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
|
| 199 |
+
sample['response'] = split_text[-1].strip()
|
| 200 |
+
sample["input_ids"] = tokenizer.encode(sample["prompt"])
|
| 201 |
+
sample["query"] = tokenizer.decode(sample["input_ids"])
|
| 202 |
+
sample["input_ids_rm1"] = rm_tokenizer1.encode(sample["prompt"])
|
| 203 |
+
sample["input_ids_rm2"] = rm_tokenizer2.encode(sample["prompt"])
|
| 204 |
+
return sample
|
| 205 |
+
|
| 206 |
+
ds_chosen = ds.map(tokenize, batched=False, num_proc=20)
|
| 207 |
+
ds_concat = ds_chosen
|
| 208 |
+
ds_concat = ds_concat.filter(lambda x: len(x["input_ids"]) <= 512 and len(x["input_ids"]) >= 8 \
|
| 209 |
+
and len(x["input_ids_rm1"]) <= 512 and len(x["input_ids_rm1"]) >= 8
|
| 210 |
+
and len(x["input_ids_rm2"]) <= 512 and len(x["input_ids_rm2"]) >= 8 )
|
| 211 |
+
ds_concat = ds_concat.remove_columns(['chosen', 'rejected','input_ids_rm1', 'input_ids_rm2', 'text', 'prompt', 'response', 'query'])
|
| 212 |
+
ds_concat.set_format(type="torch")
|
| 213 |
+
return ds_concat
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def build_dataset_summary_eval(path, tokenizer, rm_tokenizers, split='test', size=None):
|
| 217 |
+
if split == 'test':
|
| 218 |
+
split = 'validation'
|
| 219 |
+
ds = load_dataset(path, 'comparisons')
|
| 220 |
+
ds = ds[split]
|
| 221 |
+
ds = ds.filter(lambda x: x["info"]['post'] is not None and 100 < len(x["info"]['post']) < 1200, batched=False, num_proc=30)
|
| 222 |
+
|
| 223 |
+
# need to remove duplicated prompts for evaluation
|
| 224 |
+
def remove_duplicate(duplicated_dataset):
|
| 225 |
+
duplicated_dataset = duplicated_dataset.filter(lambda x: x['info']["id"] is not None)
|
| 226 |
+
initial_list = duplicated_dataset.map(lambda x: {"id": x['info']["id"]})
|
| 227 |
+
_ , unique_indices = np.unique(initial_list["id"], return_index=True, axis=0)
|
| 228 |
+
filtered_dataset = duplicated_dataset.select(unique_indices.tolist())
|
| 229 |
+
return filtered_dataset
|
| 230 |
+
|
| 231 |
+
ds = remove_duplicate(ds)
|
| 232 |
+
if size is not None:
|
| 233 |
+
ds = ds.select(range(size))
|
| 234 |
+
ds = ds.select(range(0, min(len(ds),2000))) # select 2000 data
|
| 235 |
+
|
| 236 |
+
def tokenize(sample):
|
| 237 |
+
info_post = sample["info"]["post"].replace("\n", " ")
|
| 238 |
+
prompt_summary = Instructions_summary.prompt_input(info_post)
|
| 239 |
+
sample["prompt"] = prompt_summary
|
| 240 |
+
sample["input_ids"] = tokenizer.encode(prompt_summary)
|
| 241 |
+
sample["query"] = tokenizer.decode(sample["input_ids"])
|
| 242 |
+
return sample
|
| 243 |
+
|
| 244 |
+
ds = ds.map(tokenize, batched=False, num_proc=30)
|
| 245 |
+
ds = ds.filter(lambda x: len(x["input_ids"]) <= 512 and len(x["input_ids"]) >= 8)
|
| 246 |
+
remove_columns = ['info', 'summaries', 'choice', 'worker', 'batch', 'split', 'extra']
|
| 247 |
+
ds = ds.remove_columns(remove_columns)
|
| 248 |
+
ds.set_format(type="torch")
|
| 249 |
+
return ds
|
| 250 |
+
|
| 251 |
+
def check_lora_in_model_path(model, path):
|
| 252 |
+
if os.path.exists(path):
|
| 253 |
+
dirnames = os.listdir(path)
|
| 254 |
+
if 'adapter_config.json' in dirnames:
|
| 255 |
+
return True
|
| 256 |
+
|
| 257 |
+
state_dict_keys = model.state_dict().keys()
|
| 258 |
+
for key in state_dict_keys:
|
| 259 |
+
if 'lora' in key:
|
| 260 |
+
return True
|
| 261 |
+
return False
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def load_reward_model(reward_peft_path, gpu_id):
|
| 265 |
+
num_labels = 2 if ('humor' in reward_peft_path or 'faithful' in reward_peft_path) else 1
|
| 266 |
+
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
| 267 |
+
reward_peft_path,
|
| 268 |
+
num_labels=num_labels, torch_dtype=torch.bfloat16,
|
| 269 |
+
device_map=gpu_id,
|
| 270 |
+
)
|
| 271 |
+
if check_lora_in_model_path(reward_model, reward_peft_path):
|
| 272 |
+
reward_model = PeftModel.from_pretrained(reward_model, reward_peft_path)
|
| 273 |
+
if hasattr(reward_model, 'merge_and_unload'):
|
| 274 |
+
reward_model = reward_model.merge_and_unload() # merge lora weights
|
| 275 |
+
return reward_model.to(gpu_id)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def load_main_tokenizer(tokenier_name):
|
| 279 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
| 280 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
| 281 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
| 282 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
| 283 |
+
|
| 284 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenier_name, use_fast = False)
|
| 285 |
+
tokenizer.add_special_tokens(
|
| 286 |
+
{
|
| 287 |
+
"eos_token": DEFAULT_EOS_TOKEN,
|
| 288 |
+
"bos_token": DEFAULT_BOS_TOKEN,
|
| 289 |
+
"unk_token": DEFAULT_UNK_TOKEN,
|
| 290 |
+
"pad_token": DEFAULT_PAD_TOKEN,
|
| 291 |
+
}
|
| 292 |
+
)
|
| 293 |
+
return tokenizer
|
| 294 |
+
|
| 295 |
+
def load_llama3_tokenizer(tokenier_name):
|
| 296 |
+
DEFAULT_PAD_TOKEN = "<|end_of_text|>"
|
| 297 |
+
DEFAULT_EOS_TOKEN = "<|end_of_text|>"
|
| 298 |
+
DEFAULT_BOS_TOKEN = "<|begin_of_text|>"
|
| 299 |
+
DEFAULT_UNK_TOKEN = "<|end_of_text|>"
|
| 300 |
+
|
| 301 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenier_name, use_fast = False)
|
| 302 |
+
tokenizer.add_special_tokens(
|
| 303 |
+
{
|
| 304 |
+
"eos_token": DEFAULT_EOS_TOKEN,
|
| 305 |
+
"bos_token": DEFAULT_BOS_TOKEN,
|
| 306 |
+
"unk_token": DEFAULT_UNK_TOKEN,
|
| 307 |
+
"pad_token": DEFAULT_PAD_TOKEN,
|
| 308 |
+
}
|
| 309 |
+
)
|
| 310 |
+
return tokenizer
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def get_rewards(reward_model, texts_for_rewards, reward_mean_std=None, sub_position=0):
|
| 314 |
+
rewards = []
|
| 315 |
+
print('log: reward model forwarding ...')
|
| 316 |
+
with torch.no_grad():
|
| 317 |
+
pbar = tqdm(total=len(texts_for_rewards))
|
| 318 |
+
for inputs in texts_for_rewards:
|
| 319 |
+
if sub_position != 0: # for multiple output
|
| 320 |
+
rewards.append(reward_model(**(inputs.to(reward_model.device))).logits[0][sub_position])
|
| 321 |
+
else:
|
| 322 |
+
rewards.append(reward_model(**(inputs.to(reward_model.device))).logits[0])
|
| 323 |
+
pbar.update(1)
|
| 324 |
+
|
| 325 |
+
if reward_mean_std is None:
|
| 326 |
+
rewards = [np.round(r.cpu().detach().item(), 1) for r in rewards]
|
| 327 |
+
else:
|
| 328 |
+
mean_reward, std_reward = reward_mean_std
|
| 329 |
+
rewards = [np.round((r.cpu().detach().item() - mean_reward) / std_reward, 1) for r in rewards]
|
| 330 |
+
return rewards
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def save_configs(config, path):
|
| 335 |
+
if not os.path.exists(path):
|
| 336 |
+
os.makedirs(path, exist_ok=True)
|
| 337 |
+
with open(os.path.join(path, 'training_config.txt'), 'w+') as f:
|
| 338 |
+
if type(config) == dict:
|
| 339 |
+
lines = [key + ' : ' + config[key] + '\n' for key in config.keys()]
|
| 340 |
+
f.writelines(lines)
|
| 341 |
+
else:
|
| 342 |
+
f.writelines(str(config))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def get_average_state_dict(state_dicts, coefficients):
|
| 346 |
+
i = 0
|
| 347 |
+
for state_dict, coefficient in zip(state_dicts, coefficients):
|
| 348 |
+
current_weights = state_dict
|
| 349 |
+
for key in list(current_weights.keys()):
|
| 350 |
+
if i == 0:
|
| 351 |
+
state_dicts[0][key] = coefficient * current_weights[key]
|
| 352 |
+
else:
|
| 353 |
+
state_dicts[0][key] += coefficient * current_weights[key]
|
| 354 |
+
i += 1
|
| 355 |
+
return state_dicts[0]
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def merge_weights_with_preference(base_model_names, preference, temp_save_path):
|
| 359 |
+
models = []
|
| 360 |
+
for base_model_name in base_model_names:
|
| 361 |
+
model_tmp = AutoModelForCausalLM.from_pretrained(
|
| 362 |
+
base_model_name,
|
| 363 |
+
device_map='cpu',
|
| 364 |
+
)
|
| 365 |
+
models.append(model_tmp)
|
| 366 |
+
|
| 367 |
+
state_dicts = [model_tmp.state_dict() for model_tmp in models]
|
| 368 |
+
average_weights = get_average_state_dict(state_dicts, preference)
|
| 369 |
+
model_1 = models[0]
|
| 370 |
+
model_1.load_state_dict(average_weights, strict=False)
|
| 371 |
+
if os.path.exists(temp_save_path):
|
| 372 |
+
shutil.rmtree(temp_save_path, ignore_errors=True)
|
| 373 |
+
model_1.save_pretrained(temp_save_path)
|
| 374 |
+
|
| 375 |
+
while len(models):
|
| 376 |
+
del models[0]
|
| 377 |
+
while len(state_dicts):
|
| 378 |
+
del state_dicts[0]
|
| 379 |
+
del average_weights
|
| 380 |
+
gc.collect()
|
| 381 |
+
torch.cuda.empty_cache()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def merge_lora_weight(model, path):
|
| 385 |
+
if check_lora_in_model_path(model, path):
|
| 386 |
+
model = PeftModel.from_pretrained(model, path)
|
| 387 |
+
model = model.merge_and_unload()
|
| 388 |
+
return model
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def get_clean_data(full_responses, full_prompts, remove_bad=False):
|
| 392 |
+
full_prompts_clean = []
|
| 393 |
+
full_responses_clean = []
|
| 394 |
+
for i, response in enumerate(full_responses):
|
| 395 |
+
full_prompts[i] = full_prompts[i].strip('[PAD] ').strip('[PAD]').strip('<s>').strip('</s>').strip()
|
| 396 |
+
response = response.strip('[PAD] ').strip('[PAD]').strip('<s>').strip('</s>')
|
| 397 |
+
temp_resp = response.replace(full_prompts[i], '').strip().strip('\n\n----').strip('\n\n----- ').strip()
|
| 398 |
+
if '</s>' in temp_resp:
|
| 399 |
+
temp_resp = temp_resp[:temp_resp.rindex('</s>')]
|
| 400 |
+
temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
|
| 401 |
+
temp_resp = temp_resp.split('\nHuman:')[0].strip()
|
| 402 |
+
temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
|
| 403 |
+
temp_resp = temp_resp.split('\nAssistant:')[0].strip()
|
| 404 |
+
temp_resp = temp_resp.split('\n\n\n')[0].strip()
|
| 405 |
+
clean_resp = full_prompts[i] + ' ' + temp_resp
|
| 406 |
+
if remove_bad and (('.....' in clean_resp) or (clean_resp.count(':)') >= 3)):
|
| 407 |
+
## pass bad sample
|
| 408 |
+
continue
|
| 409 |
+
full_responses_clean.append(clean_resp)
|
| 410 |
+
full_prompts_clean.append(full_prompts[i])
|
| 411 |
+
return full_prompts_clean, full_responses_clean
|
push_model.sh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
huggingface-cli repo create tmp --type model
|
| 2 |
+
cd /home/hector5/RiC/merged_models/Llama-3-Base-8B-SFT-morlhf-equal
|
| 3 |
+
git init
|
| 4 |
+
git lfs install --skip-smudge
|
| 5 |
+
git remote add origin https://huggingface.co/hour1/tmp
|
| 6 |
+
git add .
|
| 7 |
+
git commit -m "Initial commit"
|
| 8 |
+
git push origin main
|
push_with_hf.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
hf upload hour1/humor .
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# my_llm
|
| 2 |
+
|
| 3 |
+
datasets==3.2.0
|
| 4 |
+
torch==2.2.0
|
| 5 |
+
tqdm==4.67.1
|
| 6 |
+
transformers==4.45.1
|
| 7 |
+
tokenizers==0.20.3
|
| 8 |
+
accelerate==0.34.0
|
| 9 |
+
deepspeed==0.11.2
|
| 10 |
+
tyro==0.7.3
|
| 11 |
+
peft==0.10.0
|
| 12 |
+
matplotlib==3.7.1
|
| 13 |
+
wandb==0.15.4
|
| 14 |
+
bitsandbytes==0.43.2
|
| 15 |
+
trl==0.8.0
|
| 16 |
+
safetensors==0.4.5
|
| 17 |
+
scikit-learn==1.4.1.post1
|
| 18 |
+
scipy==1.10.1
|
| 19 |
+
seaborn==0.13.2
|
| 20 |
+
protobuf==3.20.0
|
| 21 |
+
numpy==1.25.0
|
| 22 |
+
flash-attn==2.6.3
|
ric/evaluation.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from accelerate import Accelerator
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, DataCollatorWithPadding
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
from trl import set_seed
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from utils import get_clean_data, Instructions_n, build_dataset_with_preference_n, \
|
| 14 |
+
load_main_tokenizer, Instructions_summary_n
|
| 15 |
+
from multi_reward_models import RewardModels
|
| 16 |
+
tqdm.pandas()
|
| 17 |
+
from utils import save_configs, map_rewards_from_preference, build_summary_dataset_with_preference_n, clean_gpu_memory
|
| 18 |
+
|
| 19 |
+
# define paths for two datasets
|
| 20 |
+
hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
|
| 21 |
+
summary_dataset_path = 'openai/summarize_from_feedback'
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ScriptArguments:
|
| 25 |
+
peft_name: Optional[str] = field(default='', metadata={'help': "path of the peft files"})
|
| 26 |
+
num_prefer_points: Optional[int] = field(default=10)
|
| 27 |
+
log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
|
| 28 |
+
save_directory: Optional[str] = field(default='./logs_trl_eval/')
|
| 29 |
+
wandb_name: Optional[str] = field(default='test', metadata={"help": "Name for this experiment"})
|
| 30 |
+
reward_names:Optional[str] = field(default='harmless,helpful')
|
| 31 |
+
base_model_name: Optional[str] = field(default='meta-llama/Llama-2-7b-hf', metadata={"help": "local path to the base model or the huggingface id"})
|
| 32 |
+
reward_stats_path: Optional[str] = field(default='')
|
| 33 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 37 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 38 |
+
exp_type = script_args.exp_type
|
| 39 |
+
base_model_name = script_args.base_model_name
|
| 40 |
+
tokenier_name = script_args.base_model_name
|
| 41 |
+
reward_stats_path = script_args.reward_stats_path if len(script_args.reward_stats_path) else None
|
| 42 |
+
print('base model: ', base_model_name)
|
| 43 |
+
|
| 44 |
+
peft_name = script_args.peft_name
|
| 45 |
+
reward_names = [x.strip() for x in script_args.reward_names.split(',')]
|
| 46 |
+
print(reward_names)
|
| 47 |
+
reward_path_tokenizer_dict = {
|
| 48 |
+
'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
|
| 49 |
+
'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
|
| 50 |
+
'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
|
| 51 |
+
'summary': ['Tristan/gpt2_reward_summarization'],
|
| 52 |
+
'faithful':['CogComp/bart-faithful-summary-detector'],
|
| 53 |
+
'humor': ['mohameddhiab/humor-no-humor'],
|
| 54 |
+
}
|
| 55 |
+
reward_model_path_list = []
|
| 56 |
+
rm_tokenizer_path_list = []
|
| 57 |
+
for name in reward_names:
|
| 58 |
+
if name not in reward_path_tokenizer_dict.keys():
|
| 59 |
+
raise NotImplementedError
|
| 60 |
+
reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 61 |
+
rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
save_info = {
|
| 65 |
+
'base_model_name': base_model_name,
|
| 66 |
+
'peft_name': peft_name,
|
| 67 |
+
'tokenier_name': tokenier_name
|
| 68 |
+
}
|
| 69 |
+
for i in range(len(reward_model_path_list)):
|
| 70 |
+
save_info['reward_peft_path{}'.format(i+1)] = reward_model_path_list[i]
|
| 71 |
+
save_configs(save_info, os.path.join(script_args.save_directory, script_args.wandb_name))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
process_id = Accelerator().local_process_index
|
| 75 |
+
gpu_id = process_id
|
| 76 |
+
print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
|
| 77 |
+
|
| 78 |
+
set_seed(8888)
|
| 79 |
+
tokenizer = load_main_tokenizer(tokenier_name)
|
| 80 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 81 |
+
base_model_name,
|
| 82 |
+
torch_dtype=torch.bfloat16, # fast inference
|
| 83 |
+
device_map=gpu_id,
|
| 84 |
+
)
|
| 85 |
+
############# very important for padding
|
| 86 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 87 |
+
if len(peft_name) > 0:
|
| 88 |
+
model = PeftModel.from_pretrained(model, peft_name)
|
| 89 |
+
if hasattr(model, 'merge_and_unload'):
|
| 90 |
+
model = model.merge_and_unload()
|
| 91 |
+
|
| 92 |
+
# load reward model
|
| 93 |
+
# do not normalization for evaluation
|
| 94 |
+
reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id, reward_stats_path)
|
| 95 |
+
num_rewards = len(reward_model_path_list)
|
| 96 |
+
instructions = Instructions_n(num_rewards) if exp_type == 'assistant' else Instructions_summary_n(num_rewards)
|
| 97 |
+
|
| 98 |
+
generation_kwargs = {
|
| 99 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 100 |
+
"min_length": -1,
|
| 101 |
+
"top_k": 0.0,
|
| 102 |
+
"top_p": 0.9,
|
| 103 |
+
"do_sample": True,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
### for evaluation
|
| 107 |
+
print('evaluation........')
|
| 108 |
+
tokenizer.padding_side = "left"
|
| 109 |
+
|
| 110 |
+
## preference list
|
| 111 |
+
if reward_models.num_rewards == 2:
|
| 112 |
+
N = script_args.num_prefer_points
|
| 113 |
+
preferences = np.zeros((N+1, 2))
|
| 114 |
+
preferences[:, 0] = np.arange(0,1 + 1/N, 1/N)
|
| 115 |
+
preferences[:, 1] = 1 - preferences[:, 0]
|
| 116 |
+
preferences = np.round(preferences, 1)
|
| 117 |
+
elif reward_models.num_rewards == 3:
|
| 118 |
+
preferences = np.array([
|
| 119 |
+
[0.0, 0.0, 1.0],
|
| 120 |
+
[0.0, 1.0, 0.0],
|
| 121 |
+
[0.1, 0.1, 0.8],
|
| 122 |
+
[0.1, 0.8, 0.1],
|
| 123 |
+
[0.2, 0.2, 0.6],
|
| 124 |
+
[0.2, 0.4, 0.4],
|
| 125 |
+
[0.2, 0.6, 0.2],
|
| 126 |
+
[0.33, 0.33, 0.33],
|
| 127 |
+
[0.4, 0.4, 0.2],
|
| 128 |
+
[0.4, 0.2, 0.4],
|
| 129 |
+
[0.6, 0.2, 0.2],
|
| 130 |
+
[0.8, 0.1, 0.1],
|
| 131 |
+
[1.0, 0.0, 0.0],
|
| 132 |
+
])
|
| 133 |
+
else:
|
| 134 |
+
raise NotImplementedError
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# using Gaussian rewards as reference
|
| 138 |
+
rewards_reference_list = [np.random.randn(50000) for _ in range(len(preferences[0]))]
|
| 139 |
+
|
| 140 |
+
def evaluate_model(
|
| 141 |
+
model,
|
| 142 |
+
reward_models,
|
| 143 |
+
tokenizer,
|
| 144 |
+
target_rewards,
|
| 145 |
+
instructions,
|
| 146 |
+
gpu_id):
|
| 147 |
+
if exp_type == 'assistant':
|
| 148 |
+
valid_dataset = build_dataset_with_preference_n(hhrlhf_dataset_path, tokenizer, rm_tokenizers, target_rewards, split='test')
|
| 149 |
+
else:
|
| 150 |
+
valid_dataset = build_summary_dataset_with_preference_n(summary_dataset_path, tokenizer, rm_tokenizers, target_rewards, split='test')
|
| 151 |
+
print(f"Size of the validation set: {len(valid_dataset)}")
|
| 152 |
+
valid_batch_size = 1
|
| 153 |
+
valid_dataset = valid_dataset.remove_columns('input_ids')
|
| 154 |
+
valid_dataset = valid_dataset.rename_column('prompt_with_score_ids', 'input_ids')
|
| 155 |
+
valid_dataset = valid_dataset.remove_columns(['prompt', 'response', 'query', 'score1', 'score2', 'prompt_with_score'])
|
| 156 |
+
for key in ['key', 'text']:
|
| 157 |
+
if key in valid_dataset.column_names:
|
| 158 |
+
valid_dataset = valid_dataset.remove_columns(key)
|
| 159 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 160 |
+
valid_data_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, drop_last=True, collate_fn=data_collator)
|
| 161 |
+
accelerator = Accelerator()
|
| 162 |
+
model, valid_data_loader = accelerator.prepare(model, valid_data_loader)
|
| 163 |
+
|
| 164 |
+
full_response_tensors = []
|
| 165 |
+
full_prompts = []
|
| 166 |
+
pbar = tqdm(total=len(valid_dataset) // valid_batch_size // accelerator.num_processes)
|
| 167 |
+
with torch.no_grad():
|
| 168 |
+
for i, batch in enumerate(valid_data_loader):
|
| 169 |
+
response_tensors = accelerator.unwrap_model(model).generate(batch['input_ids'], **generation_kwargs)
|
| 170 |
+
full_response_tensors.extend(response_tensors)
|
| 171 |
+
full_prompts.extend(batch['input_ids'])
|
| 172 |
+
pbar.update(1)
|
| 173 |
+
|
| 174 |
+
full_prompts = tokenizer.batch_decode(full_prompts)
|
| 175 |
+
full_responses = tokenizer.batch_decode(full_response_tensors)
|
| 176 |
+
# clean data
|
| 177 |
+
full_prompts, full_responses = get_clean_data(full_responses, full_prompts)
|
| 178 |
+
|
| 179 |
+
# Compute score
|
| 180 |
+
queries_responses = [(instructions.get_input(text), instructions.get_response(text)) for text in full_responses]
|
| 181 |
+
if hasattr(instructions, 'get_post'):
|
| 182 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
|
| 183 |
+
else:
|
| 184 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses)
|
| 185 |
+
|
| 186 |
+
desired_rewards_list = [[] for _ in range(reward_models.num_rewards)]
|
| 187 |
+
for text in full_responses:
|
| 188 |
+
desired_scores = instructions.get_scores(text)
|
| 189 |
+
for i in range(reward_models.num_rewards):
|
| 190 |
+
desired_rewards_list[i].append(float(desired_scores[i]))
|
| 191 |
+
|
| 192 |
+
accelerator = Accelerator()
|
| 193 |
+
accelerator.wait_for_everyone()
|
| 194 |
+
### merge data
|
| 195 |
+
### error here may because of old version of transformers/accelerate/peft
|
| 196 |
+
all_rewards = []
|
| 197 |
+
all_desired_rewards = []
|
| 198 |
+
for i in range(reward_models.num_rewards):
|
| 199 |
+
all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
|
| 200 |
+
all_desired_rewards.append(accelerator.gather_for_metrics(desired_rewards_list[i]))
|
| 201 |
+
all_full_prompts = accelerator.gather_for_metrics(full_prompts)
|
| 202 |
+
all_full_responses = accelerator.gather_for_metrics(full_responses)
|
| 203 |
+
return all_rewards, all_desired_rewards, all_full_prompts, all_full_responses
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
rm_tokenizers = []
|
| 207 |
+
for i in range(num_rewards):
|
| 208 |
+
rm_tokenizers.append(AutoTokenizer.from_pretrained(rm_tokenizer_path_list[i]))
|
| 209 |
+
|
| 210 |
+
for k in range(len(preferences)):
|
| 211 |
+
preference = preferences[k]
|
| 212 |
+
target_rewards = map_rewards_from_preference(rewards_reference_list, preference, method='l2').reshape(-1)
|
| 213 |
+
print(k, target_rewards)
|
| 214 |
+
all_rewards, all_desired_rewards, all_full_prompts, all_full_responses = evaluate_model(model, reward_models, tokenizer, target_rewards, instructions, gpu_id)
|
| 215 |
+
if process_id == 0:
|
| 216 |
+
evaluation_result = {
|
| 217 |
+
'prompt': all_full_prompts,
|
| 218 |
+
'response': all_full_responses,
|
| 219 |
+
}
|
| 220 |
+
for i in range(num_rewards):
|
| 221 |
+
evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
|
| 222 |
+
evaluation_result['desired_score{}'.format(i+1)] = all_desired_rewards[i]
|
| 223 |
+
|
| 224 |
+
print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
|
| 225 |
+
print('total average desired score {}: {}'.format(i+1, np.mean(evaluation_result['desired_score{}'.format(i+1)])))
|
| 226 |
+
|
| 227 |
+
dataframe = pd.DataFrame(evaluation_result)
|
| 228 |
+
if len(preference) == 3:
|
| 229 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}_{}.csv'.format(preference[0], preference[1], preference[2])))
|
| 230 |
+
else:
|
| 231 |
+
dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}.csv'.format(preference[0], preference[1])))
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
|
ric/generation.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from accelerate import Accelerator
|
| 3 |
+
import torch
|
| 4 |
+
from datasets import load_from_disk, disable_caching
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from transformers import AutoModelForCausalLM, DataCollatorWithPadding
|
| 7 |
+
from peft import PeftModel
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from utils import get_clean_data, load_main_tokenizer, \
|
| 12 |
+
reset_score_in_dataset, Instructions_n, clean_gpu_memory, Instructions_summary_n
|
| 13 |
+
from multi_reward_models import RewardModels
|
| 14 |
+
tqdm.pandas()
|
| 15 |
+
disable_caching()
|
| 16 |
+
from trl import set_seed
|
| 17 |
+
|
| 18 |
+
def generate_data(
|
| 19 |
+
model_path,
|
| 20 |
+
reward_model_path_list,
|
| 21 |
+
dataset,
|
| 22 |
+
tokenizer_name,
|
| 23 |
+
rm_tokenizer_path_list,
|
| 24 |
+
save_path,
|
| 25 |
+
reward_stats_path,
|
| 26 |
+
iter=0,
|
| 27 |
+
peft_name=None,
|
| 28 |
+
args=None,
|
| 29 |
+
exp_type='assistant',
|
| 30 |
+
):
|
| 31 |
+
set_seed(8888 + iter)
|
| 32 |
+
print('Generating ...')
|
| 33 |
+
process_id = Accelerator().local_process_index
|
| 34 |
+
gpu_id = process_id
|
| 35 |
+
print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
|
| 36 |
+
tokenizer = load_main_tokenizer(tokenizer_name)
|
| 37 |
+
## load model
|
| 38 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 39 |
+
model_path,
|
| 40 |
+
torch_dtype=torch.bfloat16, # fast inference
|
| 41 |
+
device_map=gpu_id,
|
| 42 |
+
)
|
| 43 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 44 |
+
if peft_name is not None:
|
| 45 |
+
model = PeftModel.from_pretrained(model, peft_name)
|
| 46 |
+
if hasattr(model, 'merge_and_unload'):
|
| 47 |
+
model = model.merge_and_unload()
|
| 48 |
+
|
| 49 |
+
generation_kwargs = {
|
| 50 |
+
"max_new_tokens": 128 if exp_type == 'assistant' else 48,
|
| 51 |
+
"min_length": -1,
|
| 52 |
+
"top_k": 0.0,
|
| 53 |
+
"top_p": 0.9,
|
| 54 |
+
'temperature': 0.7,
|
| 55 |
+
"do_sample": True,
|
| 56 |
+
"begin_suppress_tokens": [tokenizer.eos_token_id],
|
| 57 |
+
}
|
| 58 |
+
### for evaluation
|
| 59 |
+
print('generation........')
|
| 60 |
+
tokenizer.padding_side = "left"
|
| 61 |
+
|
| 62 |
+
if type(dataset) == str:
|
| 63 |
+
dataset = load_from_disk(dataset)
|
| 64 |
+
select_index = np.random.randint(0, len(dataset), args.num_generation_samples)
|
| 65 |
+
selected_dataset = dataset.select(select_index)
|
| 66 |
+
scores_name_list = []
|
| 67 |
+
for key in dataset.column_names:
|
| 68 |
+
if key.startswith('score'):
|
| 69 |
+
scores_name_list.append(key)
|
| 70 |
+
|
| 71 |
+
selected_dataset = reset_score_in_dataset(selected_dataset, tokenizer, exp_type=exp_type) #, [dataset[key] for key in scores_name_list])
|
| 72 |
+
print(f"Size of the selected dataset: {len(selected_dataset)}")
|
| 73 |
+
|
| 74 |
+
batch_size = 1 # batch size > 1 will pad on the left and will negatively influence generation
|
| 75 |
+
remove_columns = []
|
| 76 |
+
for name in ['input_ids', 'prompt', 'text', 'response', 'query', 'prompt_with_score'] + scores_name_list:
|
| 77 |
+
if name in selected_dataset.column_names:
|
| 78 |
+
remove_columns.append(name)
|
| 79 |
+
selected_dataset = selected_dataset.remove_columns(remove_columns)
|
| 80 |
+
selected_dataset = selected_dataset.rename_column('prompt_with_score_ids', 'input_ids')
|
| 81 |
+
|
| 82 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 83 |
+
data_loader = DataLoader(selected_dataset, batch_size=batch_size, drop_last=True, collate_fn=data_collator)
|
| 84 |
+
|
| 85 |
+
accelerator = Accelerator()
|
| 86 |
+
model, data_loader = accelerator.prepare(model, data_loader)
|
| 87 |
+
|
| 88 |
+
full_response_tensors = []
|
| 89 |
+
full_prompts = []
|
| 90 |
+
|
| 91 |
+
pbar = tqdm(total=len(selected_dataset) // batch_size // accelerator.num_processes)
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
for i, batch in enumerate(data_loader):
|
| 94 |
+
response_tensors = accelerator.unwrap_model(model).generate(batch['input_ids'], **generation_kwargs)
|
| 95 |
+
full_response_tensors.extend(response_tensors)
|
| 96 |
+
full_prompts.extend(batch['input_ids'])
|
| 97 |
+
pbar.update(1)
|
| 98 |
+
|
| 99 |
+
## decoding
|
| 100 |
+
full_prompts = tokenizer.batch_decode(full_prompts)
|
| 101 |
+
full_responses = tokenizer.batch_decode(full_response_tensors)
|
| 102 |
+
full_prompts, full_responses = get_clean_data(full_responses, full_prompts, remove_bad=True)
|
| 103 |
+
|
| 104 |
+
## del model and clear gpu
|
| 105 |
+
del model, data_loader
|
| 106 |
+
clean_gpu_memory()
|
| 107 |
+
|
| 108 |
+
# Compute score
|
| 109 |
+
reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id, reward_stats_path)
|
| 110 |
+
if exp_type == 'assistant':
|
| 111 |
+
instructions = Instructions_n(reward_models.num_rewards)
|
| 112 |
+
else:
|
| 113 |
+
instructions = Instructions_summary_n(reward_models.num_rewards)
|
| 114 |
+
|
| 115 |
+
queries_responses = [(instructions.get_input(text), instructions.get_response(text)) for text in full_responses]
|
| 116 |
+
|
| 117 |
+
if hasattr(instructions, 'get_post'):
|
| 118 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
|
| 119 |
+
else:
|
| 120 |
+
rewards_list = reward_models.get_reward_model_scores(queries_responses)
|
| 121 |
+
|
| 122 |
+
desired_rewards_list = [[] for _ in range(reward_models.num_rewards)]
|
| 123 |
+
for text in full_responses:
|
| 124 |
+
desired_scores = instructions.get_scores(text)
|
| 125 |
+
for i in range(reward_models.num_rewards):
|
| 126 |
+
desired_rewards_list[i].append(float(desired_scores[i]))
|
| 127 |
+
|
| 128 |
+
### merge data
|
| 129 |
+
### error here may because of old version of transformers/accelerate/peft
|
| 130 |
+
all_rewards = []
|
| 131 |
+
all_desired_rewards = []
|
| 132 |
+
for i in range(len(rewards_list)):
|
| 133 |
+
all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
|
| 134 |
+
all_desired_rewards.append(accelerator.gather_for_metrics(desired_rewards_list[i]))
|
| 135 |
+
all_full_prompts = accelerator.gather_for_metrics(full_prompts)
|
| 136 |
+
all_full_responses = accelerator.gather_for_metrics(full_responses)
|
| 137 |
+
|
| 138 |
+
if process_id == 0:
|
| 139 |
+
evaluation_result = {
|
| 140 |
+
'prompt': all_full_prompts,
|
| 141 |
+
'response': all_full_responses,
|
| 142 |
+
}
|
| 143 |
+
for i in range(len(rewards_list)):
|
| 144 |
+
evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
|
| 145 |
+
evaluation_result['desired_score{}'.format(i+1)] = all_desired_rewards[i]
|
| 146 |
+
|
| 147 |
+
print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
|
| 148 |
+
print('total average desired score {}: {}'.format(i+1, np.mean(evaluation_result['desired_score{}'.format(i+1)])))
|
| 149 |
+
|
| 150 |
+
dataframe = pd.DataFrame(evaluation_result)
|
| 151 |
+
dataframe.to_csv(os.path.join(save_path,'data.csv'))
|
| 152 |
+
|
| 153 |
+
# wait for the main process
|
| 154 |
+
accelerator.wait_for_everyone()
|
| 155 |
+
|
| 156 |
+
|
ric/main.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from accelerate import Accelerator
|
| 6 |
+
import torch
|
| 7 |
+
from utils import clean_gpu_memory, merge_dataset, save_configs
|
| 8 |
+
from training import train_model
|
| 9 |
+
from generation import generate_data
|
| 10 |
+
from transformers import HfArgumentParser
|
| 11 |
+
|
| 12 |
+
# define paths for two datasets
|
| 13 |
+
hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
|
| 14 |
+
summary_dataset_path = 'openai/summarize_from_feedback'
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class ScriptArguments:
|
| 18 |
+
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
| 19 |
+
disable_wandb: Optional[str] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
|
| 20 |
+
save_directory: Optional[str] = field(default='./logs_trl/', metadata={'help':'path'})
|
| 21 |
+
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate for online training"})
|
| 22 |
+
batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
|
| 23 |
+
training_steps: Optional[int] = field(default=20000, metadata={'help': 'number of training steps in the offline training'})
|
| 24 |
+
online_training_steps: Optional[int] = field(default=4000, metadata={'help': 'number of training steps in the online training'})
|
| 25 |
+
gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
|
| 26 |
+
num_online_iterations: Optional[int] = field(default=1, metadata={'help': 'number of the online generation and training'})
|
| 27 |
+
num_generation_samples: Optional[int] = field(default=20000, metadata={'help': 'number of samples generated'})
|
| 28 |
+
max_grad_norm: Optional[float] = field(default=1, metadata={"help": "Maximum gradient norm for gradient clipping"})
|
| 29 |
+
quantile_threshold: Optional[float] = field(default=0.7)
|
| 30 |
+
num_origin_samples: Optional[int] = field(default=10000)
|
| 31 |
+
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "loading model in 8 bit or bfloat16"})
|
| 32 |
+
bf16: Optional[bool] = field(default=False,
|
| 33 |
+
metadata={"help": "if True, training with bfloat16 (not supported by V100, but for A100, A40, A6000), otherwise we use fp32"}
|
| 34 |
+
)
|
| 35 |
+
wandb_name: Optional[str] = field(default='ric_assistant_harmlesshelpful_offline20000_lr1e-4', metadata={"help": "Name for this experiment"})
|
| 36 |
+
base_model_name: Optional[str] = field(default='meta-llama/Llama-2-7b-hf', metadata={"help": "local path to the base model or the huggingface id"})
|
| 37 |
+
peft_name: Optional[str] = field(default='', metadata={"help": "local path to the peft model"})
|
| 38 |
+
reward_names:Optional[str] = field(default='harmless,helpful')
|
| 39 |
+
train_dataset_path: Optional[str] = field(default='./datasets/all_full_train_harmhelp.hf')
|
| 40 |
+
train_reward_stats_path: Optional[str] = field(default='')
|
| 41 |
+
exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 45 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 46 |
+
exp_type = script_args.exp_type
|
| 47 |
+
base_model_name = script_args.base_model_name
|
| 48 |
+
tokenizer_name = script_args.base_model_name
|
| 49 |
+
print('base model: ', base_model_name)
|
| 50 |
+
|
| 51 |
+
if script_args.disable_wandb: # if you don't need the wandb log
|
| 52 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
| 53 |
+
|
| 54 |
+
reward_names = [x.strip() for x in script_args.reward_names.split(',')]
|
| 55 |
+
reward_path_tokenizer_dict = {
|
| 56 |
+
'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
|
| 57 |
+
'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
|
| 58 |
+
'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
|
| 59 |
+
'summary': ['Tristan/gpt2_reward_summarization'],
|
| 60 |
+
'faithful':['CogComp/bart-faithful-summary-detector'],
|
| 61 |
+
'humor': ['mohameddhiab/humor-no-humor'],
|
| 62 |
+
}
|
| 63 |
+
reward_model_path_list = []
|
| 64 |
+
rm_tokenizer_path_list = []
|
| 65 |
+
for name in reward_names:
|
| 66 |
+
if name not in reward_path_tokenizer_dict.keys():
|
| 67 |
+
raise NotImplementedError
|
| 68 |
+
reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 69 |
+
rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
|
| 70 |
+
|
| 71 |
+
# Remember to generate these datasets before training
|
| 72 |
+
train_dataset_path = script_args.train_dataset_path
|
| 73 |
+
reward_stats_path = script_args.train_reward_stats_path if len(script_args.train_reward_stats_path) else script_args.train_dataset_path + '/all_reward_stat.npy'
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
save_info = {
|
| 77 |
+
'train_dataset_path': train_dataset_path,
|
| 78 |
+
'base_model_name': base_model_name,
|
| 79 |
+
'tokenizer_name': tokenizer_name
|
| 80 |
+
}
|
| 81 |
+
for i in range(len(reward_model_path_list)):
|
| 82 |
+
save_info['reward_peft_path{}'.format(i+1)] = reward_model_path_list[i]
|
| 83 |
+
save_configs(save_info, os.path.join(script_args.save_directory, script_args.wandb_name))
|
| 84 |
+
|
| 85 |
+
save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
|
| 86 |
+
os.makedirs(save_path, exist_ok=True)
|
| 87 |
+
|
| 88 |
+
## offline training
|
| 89 |
+
dataset = train_model(
|
| 90 |
+
base_model_name=base_model_name,
|
| 91 |
+
reward_model_path_list=reward_model_path_list,
|
| 92 |
+
train_dataset=train_dataset_path,
|
| 93 |
+
save_path=save_path + '/model_iter0',
|
| 94 |
+
tokenizer_name=tokenizer_name,
|
| 95 |
+
rm_tokenizer_path_list=rm_tokenizer_path_list,
|
| 96 |
+
training_steps=script_args.training_steps,
|
| 97 |
+
learning_rate=1.414e-4, # use larger lr for offline training than online training (script_args.learning_rate)
|
| 98 |
+
args=script_args,
|
| 99 |
+
exp_type=exp_type,
|
| 100 |
+
)
|
| 101 |
+
clean_gpu_memory()
|
| 102 |
+
|
| 103 |
+
online_dataset = None
|
| 104 |
+
model_path = base_model_name
|
| 105 |
+
for i in range(script_args.num_online_iterations):
|
| 106 |
+
print('iteration {} ...'.format(i))
|
| 107 |
+
checkpoint_path = os.path.join(save_path, 'model_iter{}'.format(i))
|
| 108 |
+
if i == 0 and script_args.training_steps == 0:
|
| 109 |
+
## skip the offline training from a saved lora parameter
|
| 110 |
+
peft_name = script_args.peft_name
|
| 111 |
+
else:
|
| 112 |
+
peft_name = checkpoint_path
|
| 113 |
+
|
| 114 |
+
# ### generation
|
| 115 |
+
if script_args.num_generation_samples > 0 and not os.path.exists(os.path.join(checkpoint_path, 'data.csv')):
|
| 116 |
+
generate_data(
|
| 117 |
+
model_path,
|
| 118 |
+
reward_model_path_list=reward_model_path_list,
|
| 119 |
+
tokenizer_name=tokenizer_name,
|
| 120 |
+
rm_tokenizer_path_list=rm_tokenizer_path_list,
|
| 121 |
+
dataset=dataset,
|
| 122 |
+
save_path=os.path.join(save_path, 'model_iter{}'.format(i)),
|
| 123 |
+
peft_name=peft_name,
|
| 124 |
+
reward_stats_path=reward_stats_path,
|
| 125 |
+
iter=i,
|
| 126 |
+
args=script_args,
|
| 127 |
+
exp_type=exp_type,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
clean_gpu_memory()
|
| 131 |
+
### merging original data and generated data
|
| 132 |
+
info_path = os.path.join(script_args.save_directory, script_args.wandb_name,'reward_info.npy')
|
| 133 |
+
merged_data, online_dataset = merge_dataset(dataset, online_dataset, checkpoint_path, tokenizer_name,
|
| 134 |
+
info_path=info_path, exp_type=exp_type, quantile_threshold=script_args.quantile_threshold,
|
| 135 |
+
sample_origin=script_args.num_origin_samples)
|
| 136 |
+
|
| 137 |
+
train_model(
|
| 138 |
+
base_model_name=model_path,
|
| 139 |
+
peft_name=peft_name,
|
| 140 |
+
reward_model_path_list=reward_model_path_list,
|
| 141 |
+
train_dataset=merged_data,
|
| 142 |
+
save_path=save_path + '/model_iter{}'.format(i+1),
|
| 143 |
+
tokenizer_name=tokenizer_name,
|
| 144 |
+
rm_tokenizer_path_list=rm_tokenizer_path_list,
|
| 145 |
+
training_steps=script_args.online_training_steps,
|
| 146 |
+
learning_rate=script_args.learning_rate,
|
| 147 |
+
lr_scheduler_type='constant',
|
| 148 |
+
iter=i+1,
|
| 149 |
+
args=script_args,
|
| 150 |
+
exp_type=exp_type,
|
| 151 |
+
)
|
| 152 |
+
clean_gpu_memory()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
ric/multi_reward_models.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from utils import load_reward_model, get_rewards
|
| 6 |
+
|
| 7 |
+
class RewardModels():
|
| 8 |
+
def __init__(self, reward_model_path_list, rm_tokenizer_path_list, gpu_id_list, reward_stats_path=None):
|
| 9 |
+
assert len(reward_model_path_list) == len(rm_tokenizer_path_list)
|
| 10 |
+
self.reward_model_path_list = reward_model_path_list
|
| 11 |
+
self.rm_tokenizer_path_list = rm_tokenizer_path_list
|
| 12 |
+
self.num_rewards = len(reward_model_path_list)
|
| 13 |
+
self.reward_stats = np.load(reward_stats_path) if reward_stats_path is not None else None
|
| 14 |
+
self.reward_models = []
|
| 15 |
+
self.rm_tokenizers = []
|
| 16 |
+
if type(gpu_id_list) != list:
|
| 17 |
+
gpu_id_list = [gpu_id_list, gpu_id_list, gpu_id_list]
|
| 18 |
+
|
| 19 |
+
print('Loading reward models .....')
|
| 20 |
+
for i in range(self.num_rewards):
|
| 21 |
+
self.reward_models.append(load_reward_model(self.reward_model_path_list[i], gpu_id_list[i]))
|
| 22 |
+
self.rm_tokenizers.append(AutoTokenizer.from_pretrained(self.rm_tokenizer_path_list[i]))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_reward_model_scores(self, queries_responses, summary_fun=None):
|
| 26 |
+
texts_for_rewards = []
|
| 27 |
+
for i in range(self.num_rewards):
|
| 28 |
+
if i >= 1 and self.rm_tokenizer_path_list[i] == self.rm_tokenizer_path_list[i-1]:
|
| 29 |
+
texts_for_rewards.append(texts_for_rewards[-1])
|
| 30 |
+
elif 'faithful' in self.reward_model_path_list[i]:
|
| 31 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 32 |
+
temp_encoded_texts = [self.rm_tokenizers[i](text=r, text_pair=summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 33 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 34 |
+
elif 'summary' in self.reward_model_path_list[i] or 'summarization' in self.reward_model_path_list[i]: # reverse prompt and response
|
| 35 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 36 |
+
temp_encoded_texts = [self.rm_tokenizers[i](r + " " + self.rm_tokenizers[i].bos_token + " " + summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 37 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 38 |
+
elif 'humor' in self.reward_model_path_list[i]: # use only response
|
| 39 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 40 |
+
temp_encoded_texts = [self.rm_tokenizers[i](r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 41 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 42 |
+
else:
|
| 43 |
+
max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
|
| 44 |
+
temp_encoded_texts = [self.rm_tokenizers[i](q, r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
|
| 45 |
+
texts_for_rewards.append(temp_encoded_texts)
|
| 46 |
+
|
| 47 |
+
# normalize reward
|
| 48 |
+
rewards = []
|
| 49 |
+
for i in range(self.num_rewards):
|
| 50 |
+
if self.reward_stats is not None:
|
| 51 |
+
if type(self.reward_stats) == list or len(self.reward_stats) == 2 * self.num_rewards:
|
| 52 |
+
reward_mean_std = (self.reward_stats[2*i], self.reward_stats[2*i+1])
|
| 53 |
+
else:
|
| 54 |
+
reward_mean_std = self.reward_stats[i]
|
| 55 |
+
else:
|
| 56 |
+
reward_mean_std = None
|
| 57 |
+
|
| 58 |
+
if 'humor' in self.reward_model_path_list[i] or 'faithful' in self.reward_model_path_list[i]:
|
| 59 |
+
temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std, sub_position=1)
|
| 60 |
+
else:
|
| 61 |
+
temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std)
|
| 62 |
+
rewards.append(temp_reward)
|
| 63 |
+
return rewards
|
| 64 |
+
|