hour1 committed on
Commit f5d1134 · verified · 1 Parent(s): 2e2190c

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +10 -0
  3. .vscode/settings.json +5 -0
  4. README.md +110 -0
  5. datasets/hh-rlhf/.gitattributes +54 -0
  6. datasets/hh-rlhf/README.md +68 -0
  7. datasets/hh-rlhf/harmless-base/test.jsonl.gz +3 -0
  8. datasets/hh-rlhf/harmless-base/train.jsonl.gz +3 -0
  9. datasets/hh-rlhf/helpful-base/test.jsonl.gz +3 -0
  10. datasets/hh-rlhf/helpful-base/train.jsonl.gz +3 -0
  11. datasets/hh-rlhf/helpful-online/test.jsonl.gz +3 -0
  12. datasets/hh-rlhf/helpful-online/train.jsonl.gz +3 -0
  13. datasets/hh-rlhf/helpful-rejection-sampled/test.jsonl.gz +3 -0
  14. datasets/hh-rlhf/helpful-rejection-sampled/train.jsonl.gz +3 -0
  15. datasets/hh-rlhf/red-team-attempts/red_team_attempts.jsonl.gz +3 -0
  16. export_hf_checkpoint.py +66 -0
  17. fix_trl/utils.py +739 -0
  18. plot.py +103 -0
  19. ppo/__pycache__/multi_reward_models.cpython-310.pyc +0 -0
  20. ppo/__pycache__/utils.cpython-310.pyc +0 -0
  21. ppo/eval_ppo_single_model.py +177 -0
  22. ppo/eval_rewarded_soups.py +220 -0
  23. ppo/morlhf-llama3.py +293 -0
  24. ppo/morlhf.py +277 -0
  25. ppo/multi_reward_models.py +64 -0
  26. ppo/ppo.py +313 -0
  27. ppo/ppo_reft.py +313 -0
  28. ppo/ppo_reft_fine_grained.py +336 -0
  29. ppo/scripts/load.sh +31 -0
  30. ppo/scripts/morlhf_llama3-inst.sh +15 -0
  31. ppo/scripts/morlhf_llama3-preference.sh +37 -0
  32. ppo/scripts/morlhf_llama3-preference1.sh +32 -0
  33. ppo/scripts/morlhf_llama3-preference2.sh +32 -0
  34. ppo/scripts/morlhf_llama3.sh +17 -0
  35. ppo/scripts/ppo_reft-fine_grained.sh +17 -0
  36. ppo/scripts/ppo_reft-fine_grained2.sh +18 -0
  37. ppo/scripts/ppo_reft.sh +18 -0
  38. ppo/scripts/ppo_single_vector.sh +17 -0
  39. ppo/scripts/run.sh +2 -0
  40. ppo/scripts/test.sh +17 -0
  41. ppo/test.ipynb +248 -0
  42. ppo/test_batch_decode.ipynb +294 -0
  43. ppo/utils.py +411 -0
  44. push_model.sh +8 -0
  45. push_with_hf.sh +1 -0
  46. requirements.txt +22 -0
  47. ric/evaluation.py +234 -0
  48. ric/generation.py +156 -0
  49. ric/main.py +155 -0
  50. ric/multi_reward_models.py +64 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ric_txt2img/docs/source/en/imgs/access_request.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+ ric_txt2img/results/*
+ ric_txt2img/pretrained/*
+ images/*
+ ric_txt2img/requirements.txt
+ ric_txt2img/eval_jpeg_aes_tmp.sh
+ download_data.py
+ check_data.py
+ process_data.py
+ ric_txt2img/data/*
+ ric_txt2img/experiments/
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
+ {
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
+ "python-envs.defaultPackageManager": "ms-python.python:conda",
+ "python-envs.pythonProjects": []
+ }
README.md ADDED
@@ -0,0 +1,110 @@
+ # RiC: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment
+
+ Code for the ICML 2024 paper "Rewards-in-Context: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment". This repo is based on [trl](https://github.com/huggingface/trl).
+
+ This README covers the LLM tasks; please refer to the 'ric_txt2img' directory for the conditional diffusion model training code.
+
+ **Note:** We have updated the code to include the EOS token in data preparation for SFT and RiC, which helps resolve the issue of non-ending generation. As a result, the model produces more natural responses, although the results may differ from those in the paper. If you wish to reproduce the exact numbers reported in the paper, please use a version of the code from before November 3, 2024.
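For illustration only (the repo's actual preprocessing lives in the `ric` and `sft` scripts; the field names and `add_eos` helper below are hypothetical), appending the EOS token during data preparation amounts to something like:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def add_eos(example):
    # End each training target with the EOS token so the model learns to stop generating.
    example["text"] = example["prompt"] + example["response"] + tokenizer.eos_token
    return example

# dataset = dataset.map(add_eos)
```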
+
+ ## 1. Installation
+ Install the requirements.
+ ```
+ pip install -r requirements.txt
+ ```
+
+ The following issue has been fixed by using token IDs directly for the response_template, as described in https://huggingface.co/docs/trl/main/en/sft_trainer#train-on-completions-only.
+ ```
+ Fix the trl package.
+
+
+ The trl package currently has an issue with its 'DataCollatorForCompletionOnlyLM' implementation: the first token of the template may be tokenized inconsistently and therefore fail to match. Hence, it is necessary to modify the trl/trainer/utils.py file within your trl package (usually under {path to conda or python}/lib/python3/site-packages), specifically the 'torch_call' function starting from line 119. After installing the trl package from requirements.txt, you can simply replace trl/trainer/utils.py with the utils.py file located in the 'fix_trl' directory.
+ ```
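For reference, the token-ID approach from the TRL documentation looks roughly like this (a minimal sketch; the actual response template and the slice offset depend on the tokenizer and prompt format used in this repo):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Encode the template together with a leading newline, then drop the context-dependent
# leading token(s) so the IDs match how the template is tokenized inside a full example.
response_template_with_context = "\n### Response:"
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:]

collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```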
+
+ ## 2. Usage
+ ### 2.1 RiC
+ ```
+ Note: To reproduce the results in the paper, use the same training configuration, for example 4 processes, batch size 2, and 20000 training steps, and load the model in 8-bit (--load_in_8bit True). We have observed notable performance differences between 8-bit and 16-bit models, particularly on the summary task.
+ ```
+
+
+ * **Preparing datasets**: the first step is to generate the dataset for RiC training (see the loading sketch after this list).
+ ```
+ cd ./ric
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch prepare_dataset_with_rewards.py --reward_names 'harmless,helpful' --exp_type 'assistant' --save_directory './datasets/train_harmhelp.hf'
+ ```
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch prepare_dataset_with_rewards.py --reward_names 'summary,faithful' --exp_type 'summary' --save_directory './datasets/summary_pref1faithful.hf'
+ ```
+ Here, 'exp_type' selects one of two tasks, 'assistant' or 'summary'. Moreover, 'reward_names' can be chosen from \{'harmless', 'helpful', 'humor', 'summary', 'deberta', 'faithful'\}, where the first three are used for the 'assistant' task and the last three for the 'summary' task.
+
+ * **Training**: you can train RiC via main.py.
+
+ Offline training:
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch main.py --train_dataset_path {path_to_corresponding_dataset} --exp_type 'assistant' --reward_names 'harmless,helpful' --training_steps 20000 --num_online_iterations 0 --wandb_name 'ric_assistant_harmlesshelpful_offline20000' --batch_size 2 --load_in_8bit True
+ ```
+ Offline training + online training:
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch main.py --train_dataset_path {path_to_corresponding_dataset} --exp_type 'assistant' --reward_names 'harmless,helpful' --training_steps 20000 --online_training_steps 4000 --num_online_iterations 2 --wandb_name 'ric_assistant_harmlesshelpful_offline20000_onlineiter2' --batch_size 2 --load_in_8bit True
+ ```
+
+ * **Evaluation**: after training, models are saved to the 'save_directory' path. You can evaluate them using evaluation.py.
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch evaluation.py --peft_name {path_to_the_saved_peft} --reward_names 'harmless,helpful' --exp_type 'assistant' --wandb_name {path_of_the_experiment}
+ ```
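The prepared dataset is saved via the Hugging Face `datasets` library, so it can presumably be inspected before training with `load_from_disk` (a minimal sketch; the exact columns are whatever prepare_dataset_with_rewards.py writes):

```python
from datasets import load_from_disk

# Returns a Dataset or DatasetDict, depending on how the script saved it.
ds = load_from_disk('./datasets/train_harmhelp.hf')
print(ds)
```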
+
+ ### 2.2 SFT
+ * Training an SFT model:
+ ```
+ cd ./sft
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch sft.py --base_model_name 'meta-llama/Llama-2-7b-hf' --exp_type 'summary' --wandb_name {name_of_the_experiment}
+ ```
+ * Evaluating an SFT model:
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch eval_multiobjective.py --base_model_name {path_to_the_saved_SFT_model} --exp_type 'summary' --wandb_name {name_of_the_experiment} --reward_names 'summary,faithful'
+ ```
+
+ * Merging the SFT model if LoRA is used for fine-tuning:
+ ```
+ python3 merge_sft_lora.py --base_model_name 'meta-llama/Llama-2-7b-hf' --lora_name {path_to_the_lora_file} --save_directory {path_to_the_saving_directory}
+ ```
+
+ ### 2.3 MORLHF
+ MORLHF optimizes a preference-weighted combination of rewards with PPO, starting from the SFT model of Section 2.2.
+
+ Train MORLHF:
+ ```
+ cd ./ppo
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch morlhf.py --preference 0.5 --base_model_name {path_to_the_sft_model} --reward_names 'harmless,helpful' --exp_type 'assistant' --wandb_name 'morlhf_llamma2'
+ ```
+ Here, 'preference' is the preference weight assigned to the first reward.
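As a rough sketch of the scalarization this implies (assuming the usual MORLHF convention that the second reward receives the complementary weight; morlhf.py may differ in details such as reward normalization):

```python
import torch

def preference_weighted_reward(r1: torch.Tensor, r2: torch.Tensor, preference: float) -> torch.Tensor:
    # Combine two per-sample reward tensors into the single scalar reward optimized by PPO.
    return preference * r1 + (1.0 - preference) * r2

# Example: preference 0.5 weights both objectives equally.
rewards = preference_weighted_reward(torch.tensor([1.2, -0.3]), torch.tensor([0.4, 0.9]), preference=0.5)
```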
+
+ Evaluate the MORLHF model:
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch eval_ppo_single_model.py --base_model_name {path_to_the_saved_morlhf_model} --reward_names 'harmless,helpful' --exp_type 'assistant' --wandb_name 'eval_morlhf_llamma2'
+ ```
+
+ ### 2.4 Rewarded Soups
+ Rewarded Soups trains $N$ PPO models, one for each of the $N$ reward models, and combines them at evaluation time.
+
+ Train a PPO model for each reward model:
+ ```
+ cd ./ppo
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch ppo.py --reward_name 'harmless' --base_model_name {path_to_the_SFT_model} --exp_type 'assistant' --wandb_name {name_for_this_experiment}
+ ```
+ Evaluate Rewarded Soups with multiple PPO models:
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch eval_rewarded_soups.py --reward_names 'harmless,helpful' --base_model_path1 {path_to_the_PPO_model_for_harmless} --base_model_path2 {path_to_the_PPO_model_for_helpful} --exp_type 'assistant' --wandb_name 'eval_pposoups_harmless_helpful'
+ ```
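Conceptually, the "soup" is a linear interpolation of the trained policies' parameters. A minimal sketch over two full state dicts is shown below; eval_rewarded_soups.py may instead interpolate LoRA adapters, and the coefficient `lam` is an assumption for illustration:

```python
import torch
from transformers import AutoModelForCausalLM

def rewarded_soup(model_a, model_b, lam: float):
    """Overwrite model_a's parameters with lam * A + (1 - lam) * B and return it."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    merged = {k: lam * sd_a[k] + (1.0 - lam) * sd_b[k] for k in sd_a}
    model_a.load_state_dict(merged)
    return model_a

model_a = AutoModelForCausalLM.from_pretrained("{path_to_the_PPO_model_for_harmless}", torch_dtype=torch.float32)
model_b = AutoModelForCausalLM.from_pretrained("{path_to_the_PPO_model_for_helpful}", torch_dtype=torch.float32)
soup_model = rewarded_soup(model_a, model_b, lam=0.5)
```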
+
+
+ ## 3. Citation
+ If you find our work useful for your research, please cite:
+ ```
+ @article{yang2024rewards,
+ title={Rewards-in-Context: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment},
+ author={Yang, Rui and Pan, Xiaoman and Luo, Feng and Qiu, Shuang and Zhong, Han and Yu, Dong and Chen, Jianshu},
+ journal={International Conference on Machine Learning},
+ year={2024}
+ }
+ ```
datasets/hh-rlhf/.gitattributes ADDED
@@ -0,0 +1,54 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ # Audio files - uncompressed
37
+ *.pcm filter=lfs diff=lfs merge=lfs -text
38
+ *.sam filter=lfs diff=lfs merge=lfs -text
39
+ *.raw filter=lfs diff=lfs merge=lfs -text
40
+ # Audio files - compressed
41
+ *.aac filter=lfs diff=lfs merge=lfs -text
42
+ *.flac filter=lfs diff=lfs merge=lfs -text
43
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
44
+ *.ogg filter=lfs diff=lfs merge=lfs -text
45
+ *.wav filter=lfs diff=lfs merge=lfs -text
46
+ # Image files - uncompressed
47
+ *.bmp filter=lfs diff=lfs merge=lfs -text
48
+ *.gif filter=lfs diff=lfs merge=lfs -text
49
+ *.png filter=lfs diff=lfs merge=lfs -text
50
+ *.tiff filter=lfs diff=lfs merge=lfs -text
51
+ # Image files - compressed
52
+ *.jpg filter=lfs diff=lfs merge=lfs -text
53
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
54
+ *.webp filter=lfs diff=lfs merge=lfs -text
datasets/hh-rlhf/README.md ADDED
@@ -0,0 +1,68 @@
+ ---
+ license: mit
+ tags:
+ - human-feedback
+ ---
+
+ # Dataset Card for HH-RLHF
+
+ ## Dataset Summary
+
+
+ This repository provides access to two different kinds of data:
+ 1. Human preference data about helpfulness and harmlessness from [Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback](https://arxiv.org/abs/2204.05862). These data are meant to train preference (or reward) models for subsequent RLHF training. These data are *not* meant for supervised training of dialogue agents. Training dialogue agents on these data is likely to lead to harmful models and this should be avoided.
+ 2. Human-generated and annotated red teaming dialogues from [Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned](https://www.anthropic.com/red_teaming.pdf). These data are meant to understand how crowdworkers red team models and what types of red team attacks are successful or not. The data are *not* meant for fine-tuning or preference modeling (use the data above for preference modeling). These data are entire transcripts of conversations that are derived from the harmlessness preference modeling data described above, where only the chosen response is incorporated into the overall transcript. Furthermore, the transcripts are annotated with human and automated measurements of how harmful the overall dialogues are.
+
+ **Disclaimer**: The data (especially the harmlessness preference data and the red team data) contain content that may be offensive or upsetting. Topics include, but are not limited to, discriminatory language and discussions of abuse, violence, self-harm, exploitation, and other potentially upsetting subject matter. Please only engage with the data in accordance with your own personal risk tolerance. The data are intended for research purposes, especially research that can make models *less* harmful. The views expressed in the data do not reflect the views of Anthropic or any of its employees. As mentioned above, these data are *not* intended for training dialogue agents as this will likely lead to harmful model behavior.
+
+ Each of these datasets is described further below.
+
+ ## Human preference data about helpfulness and harmlessness (PM Data)
+
+ The data are described in the paper: [Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback](https://arxiv.org/abs/2204.05862). If you find the data useful, please cite the paper. The data format is very simple -- each line of the jsonl files contains a pair of texts, one "chosen" and one "rejected".
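For example, a record can be read straight from the compressed jsonl files shipped in this repo (a small sketch; it assumes the `chosen`/`rejected` keys described above):

```python
import gzip
import json

path = "datasets/hh-rlhf/harmless-base/train.jsonl.gz"
with gzip.open(path, "rt") as f:
    record = json.loads(f.readline())

print(record["chosen"][:300])    # preferred conversation continuation
print(record["rejected"][:300])  # dispreferred conversation continuation
```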
+
+ For **helpfulness**, the data are grouped into train/test splits in three tranches: from our base models (context-distilled 52B language models), via rejection sampling (mostly with best-of-16 sampling) against an early preference model, and a dataset sampled during our iterated "online" process.
+
+ For **harmlessness**, the data are only collected for our base models, but otherwise formatted in the same way.
+
+ Details about the data collection process and crowdworker population can be found in the paper, specifically in section 2 and appendix D.
+
+ ## Red teaming data (not PM Data)
+
+ The data are described in the paper: [Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned](https://www.anthropic.com/red_teaming.pdf). If you find the data useful, please cite the paper. At a high level these data can be examined to understand what successful red team attempts look like.
+
+ Details about the data and data collection procedures can be found in the Datasheet in the appendix of the paper.
+
+ Each line of the jsonl file contains a dictionary with the following fields:
+ - `transcript` a text transcript of a conversation between a human adversary (red team member) and an AI assistant
+ - `min_harmlessness_score_transcript` a real value score of the harmlessness of the AI assistant (lower is more harmful) as obtained from a preference model
+ - `num_params` number of parameters in the language model powering the AI assistant
+ - `model_type` type of model powering the AI assistant
+ - `rating` the red team member's rating of how successful they were at breaking the AI assistant (Likert scale, higher is more successful)
+ - `task_description` a short text description written by the red team member about how they tried to red team the AI assistant
+ - `task_description_harmlessness_score` a real value score of the harmlessness of the task description (lower is more harmful) as obtained from a preference model
+ - `red_team_member_id` an arbitrary identifier of the red team member. One red team member can generate multiple red team attacks
+ - `is_upworker` a binary indicator that is true if the red team member was from the crowd platform Upwork or false if they were from MTurk
+ - `tags` a list of up to 6 tags per transcript. Tags are short descriptions of the red team attempts generated by crowdworkers who reviewed red team data post-hoc. Tags were only provided for a random sample of 1000 red team attempts for two of four model types.
+
+ ## Usage
+
+ Each of the above datasets is located in a separate sub-directory. To load an individual subset, use the `data_dir` argument of the `load_dataset()` function as follows:
+
+ ```python
+ from datasets import load_dataset
+
+ # Load all helpfulness/harmless subsets (share the same schema)
+ dataset = load_dataset("Anthropic/hh-rlhf")
+
+ # Load one of the harmless subsets
+ dataset = load_dataset("Anthropic/hh-rlhf", data_dir="harmless-base")
+
+ # Load the red teaming subset
+ dataset = load_dataset("Anthropic/hh-rlhf", data_dir="red-team-attempts")
+ ```
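For instance, after loading the red teaming subset you can inspect the fields listed above (a small sketch; this assumes the subset exposes a single `train` split when loaded this way):

```python
from datasets import load_dataset

red_team = load_dataset("Anthropic/hh-rlhf", data_dir="red-team-attempts")["train"]
example = red_team[0]
print(example["model_type"], example["rating"])
print(example["transcript"][:500])
```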
+
+ ## Contact
+
+ The original authors host this dataset on GitHub here: https://github.com/anthropics/hh-rlhf
+ You can submit inquiries to: redteam@anthropic.com
datasets/hh-rlhf/harmless-base/test.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebfaaed21162a4de120ae85366075b34425e2a6303fac59b99555002f7016f03
3
+ size 742801
datasets/hh-rlhf/harmless-base/train.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf3e1ecd3db61bf21d190f69351d26670c67aababa0fd282fa008c17a7202217
3
+ size 13197306
datasets/hh-rlhf/helpful-base/test.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8be3fc1a13b27901631696f2be6f184c799f1baeaf145f53ac5db24960adc37b
3
+ size 875191
datasets/hh-rlhf/helpful-base/train.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:518a5bf288456fc9f3b7c980c54116fba0c52f274d3f4d344675d83e4058f6f4
3
+ size 16200131
datasets/hh-rlhf/helpful-online/test.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39a475cc7e0df972902f8e3a0972726fce826ac65369bda8c671de4330c60a10
3
+ size 1053578
datasets/hh-rlhf/helpful-online/train.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0994d18b82083f5c084ce04071b683958a724df8c7215ab1d851986642d36a2a
3
+ size 20127228
datasets/hh-rlhf/helpful-rejection-sampled/test.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abc35dbc56015db1dae0faa484514ccbc3eee0133bfcd557fc69250619b028bb
3
+ size 1361231
datasets/hh-rlhf/helpful-rejection-sampled/train.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cea2d2ab2187624ec1314e8cc6fb579f0fa987b3a33e8102b7556ed4d255f64
3
+ size 25697148
datasets/hh-rlhf/red-team-attempts/red_team_attempts.jsonl.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c7b0069991460f0064f279fd400b51f3f0095697d14d7793c49b0925f80814f
3
+ size 15483307
export_hf_checkpoint.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+
3
+ import torch
4
+ import transformers
5
+ from peft import PeftModel
6
+ from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer # noqa: F402
7
+
8
+ # BASE_MODEL = os.environ.get("BASE_MODEL", None)
9
+ # assert (
10
+ # BASE_MODEL
11
+ # ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
12
+
13
+
14
+ BASE_MODEL = '/home/hector5/models/Llama-3-Base-8B-SFT/'
15
+
16
+
17
+ # LORA_WEIGHT = '/home/hector5/RiC/ppo/ckpt_morlhf/morlhf_llamma3_lora-0.33_0.33_0.33/epoch_1/'
18
+ # LORA_WEIGHT = '/home/hector5/RiC/ppo/ckpt_morlhf/morlhf_llamma3_lora-0.2_0.2_0.6/epoch_1/'
19
+ LORA_WEIGHT = '/home/hector5/RiC/ppo/ckpt_morlhf/morlhf_llamma3_lora-0.0_0.0_1.0/epoch_0/'
20
+
21
+ OUTPUT_path = './merged_models/Llama-3-Base-8B-SFT-morlhf-humor'
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
24
+ tokenizer.save_pretrained(OUTPUT_path)
25
+
26
+ base_model = LlamaForCausalLM.from_pretrained(
27
+ BASE_MODEL,
28
+ load_in_8bit=False,
29
+ torch_dtype=torch.bfloat16,
30
+ device_map={"": "cpu"},
31
+ )
32
+
33
+ first_weight = base_model.model.layers[0].self_attn.q_proj.weight
34
+ first_weight_old = first_weight.clone()
35
+
36
+ lora_model = PeftModel.from_pretrained(
37
+ base_model,
38
+ LORA_WEIGHT,
39
+ device_map={"": "cpu"},
40
+ torch_dtype=torch.bfloat16,
41
+ )
42
+
43
+ lora_weight = lora_model.base_model.model.model.layers[
44
+ 0
45
+ ].self_attn.q_proj.weight
46
+
47
+ assert torch.allclose(first_weight_old, first_weight)
48
+
49
+ # merge weights - new merging method from peft
50
+ lora_model = lora_model.merge_and_unload()
51
+
52
+ lora_model.train(False)
53
+
54
+ # did we do anything?
55
+ assert not torch.allclose(first_weight_old, first_weight)
56
+
57
+ lora_model_sd = lora_model.state_dict()
58
+ deloreanized_sd = {
59
+ k.replace("base_model.model.", ""): v
60
+ for k, v in lora_model_sd.items()
61
+ if "lora" not in k
62
+ }
63
+
64
+ LlamaForCausalLM.save_pretrained(
65
+ base_model, OUTPUT_path, state_dict=deloreanized_sd, max_shard_size="2048MB"
66
+ )
fix_trl/utils.py ADDED
@@ -0,0 +1,739 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import random
15
+ import warnings
16
+ from collections import deque
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ from accelerate import PartialState
23
+ from torch.nn.utils.rnn import pad_sequence
24
+ from torch.utils.data import IterableDataset
25
+ from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
26
+
27
+ from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
28
+ from ..trainer.model_config import ModelConfig
29
+
30
+
31
+ if is_peft_available():
32
+ from peft import LoraConfig, PeftConfig
33
+
34
+
35
+ class AdaptiveKLController:
36
+ """
37
+ Adaptive KL controller described in the paper:
38
+ https://arxiv.org/pdf/1909.08593.pdf
39
+ """
40
+
41
+ def __init__(self, init_kl_coef, target, horizon):
42
+ self.value = init_kl_coef
43
+ self.target = target
44
+ self.horizon = horizon
45
+
46
+ def update(self, current, n_steps):
47
+ target = self.target
48
+ proportional_error = np.clip(current / target - 1, -0.2, 0.2)
49
+ mult = 1 + proportional_error * n_steps / self.horizon
50
+ self.value *= mult
51
+
52
+
53
+ class FixedKLController:
54
+ """Fixed KL controller."""
55
+
56
+ def __init__(self, kl_coef):
57
+ self.value = kl_coef
58
+
59
+ def update(self, current, n_steps):
60
+ pass
61
+
62
+
63
+ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
64
+ """
65
+ Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
66
+ when they do not come from the assistant. This ensures that the loss is only
67
+ calculated on the completion made by the assistant.
68
+
69
+ Args:
70
+ response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
71
+ '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
72
+ differently if it does not have proper context.
73
+ instruction_template (`Union[str, List[int]]`): the template form that indicates the start of the human instruction, typically something like
74
+ '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
75
+ mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
76
+ `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
77
+ for flexibility and backwards-compatibility.
78
+ ignore_index (`int`, *optional*, defaults to `-100`):
79
+ The index to use to ignore the initial tokens with
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ response_template: Union[str, List[int]],
85
+ instruction_template: Optional[Union[str, List[int]]] = None,
86
+ *args,
87
+ mlm: bool = False,
88
+ ignore_index: int = -100,
89
+ **kwargs,
90
+ ):
91
+ super().__init__(*args, mlm=mlm, **kwargs)
92
+
93
+ self.instruction_template = instruction_template
94
+ if isinstance(instruction_template, str):
95
+ # The user provides a string, must tokenize
96
+ self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
97
+ else:
98
+ # The user already provides the token ids
99
+ self.instruction_token_ids = instruction_template
100
+
101
+ self.response_template = response_template
102
+ if isinstance(response_template, str):
103
+ # The user provides a string, must tokenize
104
+ self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
105
+ else:
106
+ # The user already provides the token ids
107
+ self.response_token_ids = response_template
108
+
109
+ if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
110
+ warnings.warn(
111
+ "The pad_token_id and eos_token_id values of this tokenizer are identical. "
112
+ "If you are planning for multi-turn training, "
113
+ "it can result in the model continuously generating questions and answers without eos token. "
114
+ "To avoid this, set the pad_token_id to a different value."
115
+ )
116
+
117
+ self.ignore_index = ignore_index
118
+
119
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
120
+ batch = super().torch_call(examples)
121
+
122
+ if self.instruction_template is None:
123
+ for i in range(len(examples)):
124
+ response_token_ids_start_idx = None
125
+
126
+ for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
127
+ # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
128
+ if (
129
+ self.response_token_ids
130
+ == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
131
+ ):
132
+ response_token_ids_start_idx = idx
133
+
134
+
135
+ if response_token_ids_start_idx is None:
136
+ for idx in np.where(batch["labels"][i] == self.response_token_ids[1])[0]:
137
+ if (
138
+ self.response_token_ids[1:]
139
+ == batch["labels"][i][idx : idx + len(self.response_token_ids[1:])].tolist()
140
+ ):
141
+ response_token_ids_start_idx = idx
142
+
143
+ if response_token_ids_start_idx is None:
144
+ warnings.warn(
145
+ f"Could not find response key `{self.response_template}` in the "
146
+ f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
147
+ f"This instance will be ignored in loss calculation. "
148
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
149
+ )
150
+ batch["labels"][i, :] = self.ignore_index
151
+ else:
152
+ response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids[1:])
153
+ else:
154
+ response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
155
+
156
+ # Make pytorch loss function ignore all tokens up through the end of the response key
157
+ batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
158
+
159
+ else:
160
+ for i in range(len(examples)):
161
+ response_token_ids_idxs = []
162
+ human_token_ids_idxs = []
163
+
164
+ for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
165
+ # find the indexes of the start of a response.
166
+ if (
167
+ self.response_token_ids
168
+ == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
169
+ ):
170
+ response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
171
+
172
+ if len(response_token_ids_idxs) == 0:
173
+ warnings.warn(
174
+ f"Could not find response key `{self.response_template}` in the "
175
+ f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
176
+ f"This instance will be ignored in loss calculation. "
177
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
178
+ )
179
+ batch["labels"][i, :] = self.ignore_index
180
+
181
+ human_token_ids = self.instruction_token_ids
182
+ for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
183
+ # find the indexes of the start of a human answer.
184
+ if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
185
+ human_token_ids_idxs.append(human_idx)
186
+
187
+ if len(human_token_ids_idxs) == 0:
188
+ warnings.warn(
189
+ f"Could not find instruction key `{self.instruction_template}` in the "
190
+ f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
191
+ f"This instance will be ignored in loss calculation. "
192
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
193
+ )
194
+ batch["labels"][i, :] = self.ignore_index
195
+
196
+ if (
197
+ len(human_token_ids_idxs) > 0
198
+ and len(response_token_ids_idxs) > 0
199
+ and human_token_ids_idxs[0] > response_token_ids_idxs[0]
200
+ ):
201
+ human_token_ids_idxs = [0] + human_token_ids_idxs
202
+
203
+ for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
204
+ # Make pytorch loss function ignore all non response tokens
205
+ if idx != 0:
206
+ batch["labels"][i, start:end] = self.ignore_index
207
+ else:
208
+ batch["labels"][i, :end] = self.ignore_index
209
+
210
+ if len(response_token_ids_idxs) < len(human_token_ids_idxs):
211
+ batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
212
+
213
+ return batch
214
+
215
+
216
+ @dataclass
217
+ class RewardDataCollatorWithPadding:
218
+ r"""
219
+ Reward DataCollator class that pads the inputs to the maximum length of the batch.
220
+ Args:
221
+ tokenizer (`PreTrainedTokenizerBase`):
222
+ The tokenizer used for encoding the data.
223
+ padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
224
+ padding_strategy to pass to the tokenizer.
225
+ max_length (`Optional[int]`, `optional`, defaults to `None`):
226
+ The maximum length of the sequence to be processed.
227
+ pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`):
228
+ If set will pad the sequence to a multiple of the provided value.
229
+ return_tensors (`str`, `optional`, defaults to `"pt"`):
230
+ The tensor type to use.
231
+ """
232
+
233
+ tokenizer: PreTrainedTokenizerBase
234
+ padding: Union[bool, str] = True
235
+ max_length: Optional[int] = None
236
+ pad_to_multiple_of: Optional[int] = None
237
+ return_tensors: str = "pt"
238
+
239
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
240
+ features_chosen = []
241
+ features_rejected = []
242
+ margin = []
243
+ # check if we have a margin. If we do, we need to batch it as well
244
+ has_margin = "margin" in features[0]
245
+ for feature in features:
246
+ # check if the keys are named as expected
247
+ if (
248
+ "input_ids_chosen" not in feature
249
+ or "input_ids_rejected" not in feature
250
+ or "attention_mask_chosen" not in feature
251
+ or "attention_mask_rejected" not in feature
252
+ ):
253
+ raise ValueError(
254
+ "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
255
+ )
256
+
257
+ features_chosen.append(
258
+ {
259
+ "input_ids": feature["input_ids_chosen"],
260
+ "attention_mask": feature["attention_mask_chosen"],
261
+ }
262
+ )
263
+ features_rejected.append(
264
+ {
265
+ "input_ids": feature["input_ids_rejected"],
266
+ "attention_mask": feature["attention_mask_rejected"],
267
+ }
268
+ )
269
+ if has_margin:
270
+ margin.append(feature["margin"])
271
+ batch_chosen = self.tokenizer.pad(
272
+ features_chosen,
273
+ padding=self.padding,
274
+ max_length=self.max_length,
275
+ pad_to_multiple_of=self.pad_to_multiple_of,
276
+ return_tensors=self.return_tensors,
277
+ )
278
+ batch_rejected = self.tokenizer.pad(
279
+ features_rejected,
280
+ padding=self.padding,
281
+ max_length=self.max_length,
282
+ pad_to_multiple_of=self.pad_to_multiple_of,
283
+ return_tensors=self.return_tensors,
284
+ )
285
+ batch = {
286
+ "input_ids_chosen": batch_chosen["input_ids"],
287
+ "attention_mask_chosen": batch_chosen["attention_mask"],
288
+ "input_ids_rejected": batch_rejected["input_ids"],
289
+ "attention_mask_rejected": batch_rejected["attention_mask"],
290
+ "return_loss": True,
291
+ }
292
+ if has_margin:
293
+ margin = torch.tensor(margin, dtype=torch.float)
294
+ batch["margin"] = margin
295
+ return batch
296
+
297
+
298
+ @dataclass
299
+ class DPODataCollatorWithPadding:
300
+ r"""
301
+ DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
302
+ Args:
303
+ pad_token_id (`int` defaults to 0):
304
+ The tokenizer's pad_token_id.
305
+ label_pad_token_id (`int`, defaults to -100):
306
+ The label used for masking.
307
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
308
+ Whether or not your model has an encoder_decoder architecture.
309
+ """
310
+
311
+ pad_token_id: int = 0
312
+ label_pad_token_id: int = -100
313
+ is_encoder_decoder: Optional[bool] = False
314
+
315
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
316
+ # first, pad everything to the same length
317
+ padded_batch = {}
318
+ for k in features[0].keys():
319
+ if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
320
+ if self.is_encoder_decoder:
321
+ to_pad = [torch.LongTensor(ex[k]) for ex in features]
322
+
323
+ if (k.startswith("prompt")) and (k.endswith("input_ids")):
324
+ if self.pad_token_id is None:
325
+ raise ValueError(
326
+ "Padding is enabled, but the tokenizer is not configured with a padding token."
327
+ " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
328
+ " before calling the trainer."
329
+ )
330
+ padding_value = self.pad_token_id
331
+ elif k.endswith("_attention_mask"):
332
+ padding_value = 0
333
+ elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
334
+ padding_value = self.label_pad_token_id
335
+ else:
336
+ raise ValueError(f"Unexpected key in batch '{k}'")
337
+ padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
338
+ else:
339
+ # adapted from https://stackoverflow.com/questions/73256206
340
+ if "prompt" in k:
341
+ to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features]
342
+ else:
343
+ to_pad = [torch.LongTensor(ex[k]) for ex in features]
344
+ if k.endswith("_input_ids"):
345
+ if self.pad_token_id is None:
346
+ raise ValueError(
347
+ "Padding is enabled, but the tokenizer is not configured with a padding token."
348
+ " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
349
+ " before calling the trainer."
350
+ )
351
+ padding_value = self.pad_token_id
352
+ elif k.endswith("_labels"):
353
+ padding_value = self.label_pad_token_id
354
+ elif k.endswith("_attention_mask"):
355
+ padding_value = 0
356
+ else:
357
+ raise ValueError(f"Unexpected key in batch '{k}'")
358
+
359
+ padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
360
+ # for the prompt, flip back so padding is on left side
361
+ if "prompt" in k:
362
+ padded_batch[k] = padded_batch[k].flip(dims=[1])
363
+ elif k.endswith("_logps"):
364
+ # the cached reference model logprobs
365
+ padded_batch[k] = torch.tensor([ex[k] for ex in features])
366
+ else:
367
+ padded_batch[k] = [ex[k] for ex in features]
368
+
369
+ return padded_batch
370
+
371
+
372
+ class ConstantLengthDataset(IterableDataset):
373
+ """
374
+ Iterable dataset that returns constant length chunks of tokens from stream of text files.
375
+ The dataset also formats the text before tokenization with a specific format that is provided
376
+ by the user.
377
+
378
+ Args:
379
+ tokenizer (`transformers.PreTrainedTokenizer`):
380
+ The processor used for processing the data.
381
+ dataset (`dataset.Dataset`):
382
+ Dataset with text files.
383
+ dataset_text_field (`str`, **optional**):
384
+ Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`.
385
+ formatting_func (`Callable`, **optional**):
386
+ Function that formats the text before tokenization. Usually it is recommended to follow a certain
387
+ pattern such as `"### Question: {question} ### Answer: {answer}"`
388
+ infinite (`bool`, *optional*, defaults to `False`):
389
+ If True, the iterator is reset when the dataset reaches its end; otherwise, iteration stops.
390
+ seq_length (`int`, *optional*, defaults to `1024`):
391
+ Length of token sequences to return.
392
+ num_of_sequences (`int`, *optional*, defaults to `1024`):
393
+ Number of token sequences to keep in buffer.
394
+ chars_per_token (`int`, *optional*, defaults to `3.6`):
395
+ Number of characters per token used to estimate number of tokens in text buffer.
396
+ eos_token_id (`int`, *optional*, defaults to `0`):
397
+ Id of the end of sequence token if the passed tokenizer does not have an EOS token.
398
+ shuffle ('bool', *optional*, defaults to True)
399
+ Shuffle the examples before they are returned
400
+ append_concat_token ('bool', *optional*, defaults to True)
401
+ If true, appends `eos_token_id` at the end of each sample being packed.
402
+ add_special_tokens ('bool', *optional*, defaults to True)
403
+ If true, tokenizers adds special tokens to each sample being packed.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ tokenizer,
409
+ dataset,
410
+ dataset_text_field=None,
411
+ formatting_func=None,
412
+ infinite=False,
413
+ seq_length=1024,
414
+ num_of_sequences=1024,
415
+ chars_per_token=3.6,
416
+ eos_token_id=0,
417
+ shuffle=True,
418
+ append_concat_token=True,
419
+ add_special_tokens=True,
420
+ ):
421
+ self.tokenizer = tokenizer
422
+
423
+ if tokenizer.eos_token_id is None:
424
+ warnings.warn(
425
+ "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds"
426
+ f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id."
427
+ )
428
+
429
+ self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
430
+ self.dataset = dataset
431
+ self.seq_length = seq_length
432
+ self.infinite = infinite
433
+ self.current_size = 0
434
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
435
+ self.shuffle = shuffle
436
+ self.append_concat_token = append_concat_token
437
+ self.add_special_tokens = add_special_tokens
438
+ if formatting_func is None:
439
+ self.formatting_func = lambda x: x[dataset_text_field]
440
+ else:
441
+ self.formatting_func = formatting_func
442
+
443
+ if formatting_func is not None:
444
+ if formatting_func.__code__.co_argcount > 1:
445
+ warnings.warn(
446
+ "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
447
+ " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
448
+ )
449
+
450
+ def __len__(self):
451
+ return len(self.dataset)
452
+
453
+ def __iter__(self):
454
+ iterator = iter(self.dataset)
455
+ more_examples = True
456
+ while more_examples:
457
+ buffer, buffer_len = [], 0
458
+ while True:
459
+ if buffer_len >= self.max_buffer_size:
460
+ break
461
+ try:
462
+ buffer.append(self.formatting_func(next(iterator)))
463
+ buffer_len += len(buffer[-1])
464
+ except StopIteration:
465
+ if self.infinite:
466
+ iterator = iter(self.dataset)
467
+ warnings.warn("The dataset reached end and the iterator is reset to the start.")
468
+ else:
469
+ more_examples = False
470
+ break
471
+ tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
472
+ "input_ids"
473
+ ]
474
+ all_token_ids = []
475
+ for tokenized_input in tokenized_inputs:
476
+ if self.append_concat_token:
477
+ tokenized_input = tokenized_input + [self.concat_token_id]
478
+ all_token_ids.extend(tokenized_input)
479
+ examples = []
480
+ for i in range(0, len(all_token_ids), self.seq_length):
481
+ input_ids = all_token_ids[i : i + self.seq_length]
482
+ if len(input_ids) == self.seq_length:
483
+ examples.append(input_ids)
484
+ if self.shuffle:
485
+ random.shuffle(examples)
486
+ for example in examples:
487
+ self.current_size += 1
488
+ yield {
489
+ "input_ids": torch.LongTensor(example),
490
+ "labels": torch.LongTensor(example),
491
+ }
492
+
493
+
494
+ class RunningMoments:
495
+ def __init__(self, accelerator):
496
+ """
497
+ Calculates the running mean and standard deviation of a data stream. Reference:
498
+ https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
499
+ """
500
+ self.mean = 0
501
+ self.std = 1
502
+ self.var = 1
503
+ self.count = 1e-24
504
+ self.accelerator = accelerator
505
+
506
+ @torch.no_grad()
507
+ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
508
+ """
509
+ Updates running moments from batch's moments computed across ranks
510
+ """
511
+ if self.accelerator.use_distributed:
512
+ xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
513
+ else:
514
+ xs_count = xs.numel()
515
+ xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
516
+ xs_mean, xs_var = xs_mean.float(), xs_var.float()
517
+
518
+ delta = xs_mean - self.mean
519
+ tot_count = self.count + xs_count
520
+
521
+ new_sum = xs_var * xs_count
522
+ # correct old_sum deviation accounting for the new mean
523
+ old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
524
+ tot_sum = old_sum + new_sum
525
+
526
+ self.mean += delta * xs_count / tot_count
527
+ self.var = tot_sum / tot_count
528
+ self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
529
+ self.count = tot_count
530
+
531
+ return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()
532
+
533
+
534
+ @torch.no_grad()
535
+ def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
536
+ """
537
+ Computes element-wise mean and variance of the tensor across processes. Reference:
538
+ https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
539
+ """
540
+ xs = xs.to(accelerator.device)
541
+ sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
542
+ sum_and_count = accelerator.reduce(sum_and_count)
543
+ global_sum, count = sum_and_count
544
+ global_mean = global_sum / count
545
+
546
+ sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
547
+ sum_var = accelerator.reduce(sum_var)
548
+ global_var = sum_var / count
549
+
550
+ return global_mean.to(device), global_var.to(device), count.to(device)
551
+
552
+
553
+ def compute_accuracy(eval_pred) -> Dict[str, float]:
554
+ predictions, labels = eval_pred
555
+ # Here, predictions is rewards_chosen and rewards_rejected.
556
+ # We want to see how much of the time rewards_chosen > rewards_rejected.
557
+ if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
558
+ warnings.warn(
559
+ f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
560
+ )
561
+ predictions = np.argmax(predictions, axis=1)
562
+
563
+ accuracy = np.array(predictions == labels, dtype=float).mean().item()
564
+ return {"accuracy": accuracy}
565
+
566
+
567
+ def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
568
+ if tensor.size(dim) >= length:
569
+ return tensor
570
+ else:
571
+ pad_size = list(tensor.shape)
572
+ pad_size[dim] = length - tensor.size(dim)
573
+ return torch.cat(
574
+ [
575
+ tensor,
576
+ pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
577
+ ],
578
+ dim=dim,
579
+ )
580
+
581
+
582
+ def disable_dropout_in_model(model: torch.nn.Module) -> None:
583
+ for module in model.modules():
584
+ if isinstance(module, torch.nn.Dropout):
585
+ module.p = 0
586
+
587
+
588
+ def exact_div(a, b, a_str, b_str, custom_error_message=""):
589
+ q = a // b
590
+ if a != q * b:
591
+ raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
592
+ return q
593
+
594
+
595
+ # copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
596
+ class PerPromptStatTracker:
597
+ r"""
598
+ Class for tracking statistics per prompt. Mainly used to calculate advantage for the DDPO algorithm
599
+
600
+ Args:
601
+ buffer_size (`int`):
602
+ Size of the buffer to keep for each prompt.
603
+ min_count (`int`):
604
+ Minimum number of samples to keep in the buffer before calculating the mean and std.
605
+ """
606
+
607
+ def __init__(self, buffer_size, min_count):
608
+ self.buffer_size = buffer_size
609
+ self.min_count = min_count
610
+ self.stats = {}
611
+
612
+ def update(self, prompts, rewards):
613
+ prompts = np.array(prompts)
614
+ rewards = np.array(rewards)
615
+ unique = np.unique(prompts)
616
+ advantages = np.empty_like(rewards)
617
+ for prompt in unique:
618
+ prompt_rewards = rewards[prompts == prompt]
619
+ if prompt not in self.stats:
620
+ self.stats[prompt] = deque(maxlen=self.buffer_size)
621
+ self.stats[prompt].extend(prompt_rewards)
622
+
623
+ if len(self.stats[prompt]) < self.min_count:
624
+ mean = np.mean(rewards)
625
+ std = np.std(rewards) + 1e-6
626
+ else:
627
+ mean = np.mean(self.stats[prompt])
628
+ std = np.std(self.stats[prompt]) + 1e-6
629
+ advantages[prompts == prompt] = (prompt_rewards - mean) / std
630
+
631
+ return advantages
632
+
633
+ def get_stats(self):
634
+ return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
635
+
636
+
637
+ def neftune_post_forward_hook(module, input, output):
638
+ """
639
+ Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
640
+ torch.nn.Embedding layers. This method is slightly adapted from the original source code
641
+ that can be found here: https://github.com/neelsjain/NEFTune
642
+
643
+ Simply add it to your model as follows:
644
+ ```python
645
+ model = ...
646
+ model.embed_tokens.neftune_noise_alpha = 0.1
647
+ model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
648
+ ```
649
+
650
+ Args:
651
+ module (`torch.nn.Module`):
652
+ The embedding module where the hook is attached. Note that you need to set
653
+ `module.neftune_noise_alpha` to the desired noise alpha value.
654
+ input (`torch.Tensor`):
655
+ The input tensor to the model.
656
+ output (`torch.Tensor`):
657
+ The output tensor of the model (i.e. the embeddings).
658
+ """
659
+ if module.training:
660
+ dims = torch.tensor(output.size(1) * output.size(2))
661
+ mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
662
+ output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
663
+ return output
664
+
665
+
666
+ def peft_module_casting_to_bf16(model):
667
+ from peft.tuners.tuners_utils import BaseTunerLayer
668
+
669
+ for name, module in model.named_modules():
670
+ if isinstance(module, BaseTunerLayer):
671
+ module = module.to(torch.bfloat16)
672
+ elif isinstance(module, torch.nn.LayerNorm) or "norm" in name:
673
+ module = module.to(torch.float32)
674
+ elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
675
+ if hasattr(module, "weight"):
676
+ if module.weight.dtype == torch.float32:
677
+ module = module.to(torch.bfloat16)
678
+
679
+
680
+ def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
681
+ if is_unsloth_available():
682
+ # Unsloth adds a new attribute in the model config `unsloth_version`
683
+ # to keep track of models that have been patched with unsloth.
684
+ if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None:
685
+ tag_names.append("unsloth")
686
+
687
+ if kwargs is not None:
688
+ if "tags" not in kwargs:
689
+ kwargs["tags"] = tag_names
690
+ elif "tags" in kwargs and isinstance(kwargs["tags"], list):
691
+ kwargs["tags"].extend(tag_names)
692
+ elif "tags" in kwargs and isinstance(kwargs["tags"], str):
693
+ tag_names.append(kwargs["tags"])
694
+ kwargs["tags"] = tag_names
695
+ return kwargs
696
+
697
+
698
+ def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]:
699
+ if model_config.load_in_4bit:
700
+ quantization_config = BitsAndBytesConfig(
701
+ load_in_4bit=True,
702
+ bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
703
+ bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
704
+ bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
705
+ )
706
+ elif model_config.load_in_8bit:
707
+ quantization_config = BitsAndBytesConfig(
708
+ load_in_8bit=True,
709
+ )
710
+ else:
711
+ quantization_config = None
712
+
713
+ return quantization_config
714
+
715
+
716
+ def get_kbit_device_map() -> Optional[Dict[str, int]]:
717
+ if is_xpu_available():
718
+ return {"": f"xpu:{PartialState().local_process_index}"}
719
+ elif torch.cuda.is_available():
720
+ return {"": PartialState().local_process_index}
721
+ else:
722
+ return None
723
+
724
+
725
+ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
726
+ if model_config.use_peft is False:
727
+ return None
728
+
729
+ peft_config = LoraConfig(
730
+ r=model_config.lora_r,
731
+ lora_alpha=model_config.lora_alpha,
732
+ lora_dropout=model_config.lora_dropout,
733
+ bias="none",
734
+ task_type="CAUSAL_LM",
735
+ target_modules=model_config.lora_target_modules,
736
+ modules_to_save=model_config.lora_modules_to_save,
737
+ )
738
+
739
+ return peft_config
plot.py ADDED
@@ -0,0 +1,103 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ import glob2
6
+ import seaborn as sns
7
+
8
+
9
+ def find_pareto_points(obtained_scores, threshold=0.02):
10
+ n = len(obtained_scores)
11
+ if n == 1:
12
+ return obtained_scores
13
+ pareto_index = []
14
+ high_low = np.max(obtained_scores, axis=0) - np.min(obtained_scores, axis=0)
15
+ for i in range(n):
16
+ if not any(np.all((obtained_scores - obtained_scores[i] - threshold * high_low) > 0.0, axis=1)):
17
+ pareto_index.append(i)
18
+
19
+ points = obtained_scores[np.array(pareto_index)]
20
+ arg_index = np.argsort(points[:, 0])
21
+ points = points[arg_index]
22
+ print(points)
23
+ sorted_index = [0]
24
+ remaining_index = np.ones(len(points))
25
+ i = 0
26
+ remaining_index[i] = 0
27
+ while sum(remaining_index):
28
+ distance = ((points[np.where(remaining_index)] - points[i]) ** 2 ).sum(axis=1)
29
+ min_index = np.where(remaining_index > 0)[0][np.argmin(distance)]
30
+ sorted_index.append(min_index)
31
+ i = min_index
32
+ remaining_index[i] = 0
33
+ return points[np.array(sorted_index)]
34
+
35
+
36
+
37
+ index = 1
38
+ colors = sns.color_palette('Paired')
39
+ def plot_points(dir, label, style='-*', color='b', shift=[0,0], txt_color='black', normalize_path=None, reverse=True):
40
+ threshold = 0.01
41
+ desired_scores = []
42
+ obtained_scores = []
43
+
44
+ paths = [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*.csv'))]
45
+ paths += [os.path.abspath(path) for path in glob2.glob(os.path.join(dir, '*', '*.csv'))]
46
+
47
+ pref_lis = []
48
+ for path in paths:
49
+ if '.csv' in path:
50
+ full_path = path
51
+ data = pd.read_csv(full_path)
52
+ # MORLHF has fewer points; use a larger threshold so the frontier looks better
53
+ if 'ppo' in path and len(paths) <= 5:
54
+ threshold = 0.5
55
+ obtained_scores.append([np.mean(data['obtained_score1']), np.mean(data['obtained_score2'])])
56
+ if 'pref' in path:
57
+ # get the preference
58
+ if 'eval_data_pref' in path:
59
+ pref = path.split('eval_data_pref')[-1].strip().split('_')[0]
60
+ pref_lis.append(float(pref))
61
+
62
+ print(pref_lis)
63
+ desired_scores = np.array(desired_scores)
64
+ obtained_scores = np.array(obtained_scores)
65
+
66
+ if normalize_path is not None:
67
+ norm_info = np.load(normalize_path)
68
+ norm_info = np.array(norm_info).reshape(2, 2)
69
+ for i in range(2):
70
+ obtained_scores[:, i] = (obtained_scores[:, i] - norm_info[i][0]) / norm_info[i][1]
71
+
72
+ global index
73
+ markersize = 10 if ('*' in style or 'o' in style) else 9
74
+ pareto_points = find_pareto_points(obtained_scores, threshold)
75
+ plt.scatter(obtained_scores[:, 0], obtained_scores[:, 1], marker=style[-1], color=colors[index], s=markersize + 60)
76
+ if len(pref_lis):
77
+ for i in range(len(obtained_scores)):
78
+ plt.annotate('{}'.format(round(pref_lis[i], 1)), (obtained_scores[i, 0] + shift[0], obtained_scores[i, 1] + shift[1]), size=4, color=txt_color)
79
+
80
+ plt.plot(pareto_points[:, 0], pareto_points[:, 1], style, c=colors[index], markersize=markersize, label=label)
81
+ index += 2
82
+
83
+
84
+ plt.figure(figsize=(5, 4))
85
+
86
+ name1 = 'harmless'
87
+ name2 = 'helpful'
88
+
89
+ ### replace the paths to your own paths
90
+ plot_points('./logs_trl/eval_pretrained', 'Llama 2 base', '*')
91
+ plot_points('./logs_trl/eval_sft_alldata', 'SFT', '*')
92
+ plot_points('./eval_ppo_pref/', 'MORLHF', '--D', shift=[-0.012, -0.022])
93
+ plot_points('./logs_ppo/eval_pposoups_llamma2_klreg0.2', 'Rewarded Soups', style='--s', shift=[-0.012, -0.022])
94
+ plot_points('./logs_trl/evalnew_onlinefix_helpful_harmlesshelpful_iter2', 'RiC', style='-o', shift=[-0.012, -0.022], txt_color='white')
95
+
96
+
97
+ plt.xlabel('$R_1$ ({})'.format(name1), fontsize=12)
98
+ plt.ylabel('$R_2$ ({})'.format(name2), fontsize=12)
99
+ plt.legend(fontsize=11, loc='lower left')
100
+ plt.tight_layout()
101
+ plt.savefig('ric_assistant_{}_{}.pdf'.format(name1, name2))
102
+
103
+
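A toy sanity check for `find_pareto_points`, assuming the function above is in scope (pasting this directly below the function definition avoids re-running the module-level plotting code); the score values are made up:

import numpy as np

toy_scores = np.array([
    [0.1, 0.9],   # strong on reward 2
    [0.5, 0.5],   # balanced
    [0.9, 0.1],   # strong on reward 1
    [0.2, 0.2],   # dominated by every point above
])
frontier = find_pareto_points(toy_scores, threshold=0.02)
# The dominated point is filtered out; the rest come back ordered along the
# frontier via the nearest-neighbour chaining in the second half of the function.
print(frontier)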
ppo/__pycache__/multi_reward_models.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
ppo/__pycache__/utils.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
ppo/eval_ppo_single_model.py ADDED
@@ -0,0 +1,177 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+ from accelerate import Accelerator
5
+ import torch
6
+ from tqdm import tqdm
7
+ from transformers import AutoModelForCausalLM, HfArgumentParser
8
+ from transformers import DataCollatorWithPadding
9
+ from trl import set_seed
10
+ import numpy as np
11
+ import pandas as pd
12
+ from torch.utils.data import DataLoader
13
+ from utils import load_main_tokenizer
14
+ from multi_reward_models import RewardModels
15
+ from utils import Instructions, Instructions_summary, build_dataset_eval, build_dataset_summary_eval, check_lora_in_model_path, merge_lora_weight
16
+ tqdm.pandas()
17
+
18
+ # define paths for two datasets
19
+ hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
20
+ summary_dataset_path = 'openai/summarize_from_feedback'
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
25
+ save_directory: Optional[str] = field(default='./logs_ppo_assistant')
26
+ batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
27
+ wandb_name: Optional[str] = field(default='eval_morlhf_kl0.2_multirew0.2_gen128_harmless_helpful', metadata={"help": "Name for this experiment"})
28
+ reward_names:Optional[str] = field(default='harmless,helpful')
29
+ base_model_name: Optional[str] = field(default='./logs_ppo_prefnew/ppo_llamma2_klreg0.2_harmless_helpful_preference0.5_0.5/batch_207')
30
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
31
+
32
+ parser = HfArgumentParser(ScriptArguments)
33
+ script_args = parser.parse_args_into_dataclasses()[0]
34
+ exp_type = script_args.exp_type
35
+ base_model_name = script_args.base_model_name
36
+ tokenier_name = script_args.base_model_name
37
+ print('base model: ', base_model_name)
38
+
39
+ reward_names = [x.strip() for x in script_args.reward_names.split(',')]
40
+ num_rewards = len(reward_names)
41
+ print('number of rewards: {}'.format(num_rewards))
42
+ reward_path_tokenizer_dict = {
43
+ 'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
44
+ 'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
45
+ 'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
46
+ 'summary': ['Tristan/gpt2_reward_summarization'],
47
+ 'faithful':['CogComp/bart-faithful-summary-detector'],
48
+ 'humor': ['mohameddhiab/humor-no-humor'],
49
+ }
50
+ reward_model_path_list = []
51
+ rm_tokenizer_path_list = []
52
+ for name in reward_names:
53
+ if name not in reward_path_tokenizer_dict.keys():
54
+ raise NotImplementedError
55
+ reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
56
+ rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
57
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
58
+
59
+
60
+ process_id = Accelerator().local_process_index
61
+ gpu_id = process_id
62
+ print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
63
+
64
+
65
+ ##############
66
+ rm_gpu_id = gpu_id
67
+ reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, rm_gpu_id)
68
+
69
+ set_seed(8888)
70
+ current_device = Accelerator().local_process_index
71
+ print(current_device)
72
+
73
+
74
+ tokenizer = load_main_tokenizer(tokenier_name)
75
+ model = AutoModelForCausalLM.from_pretrained(
76
+ base_model_name,
77
+ torch_dtype=torch.bfloat16, # fast inference
78
+ device_map=gpu_id,
79
+ )
80
+ model.resize_token_embeddings(len(tokenizer))
81
+ model = merge_lora_weight(model, base_model_name)
82
+
83
+ generation_kwargs = {
84
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
85
+ "min_length": -1,
86
+ "top_k": 0.0,
87
+ "top_p": 0.9,
88
+ "do_sample": True,
89
+ }
90
+
91
+
92
+ def get_clean_data(full_responses, full_prompts):
93
+ full_responses_clean = []
94
+ for i, response in enumerate(full_responses):
95
+ full_prompts[i] = full_prompts[i].strip('[PAD] ')
96
+ response = response.strip('[PAD] ')
97
+ temp_resp = response.replace(full_prompts[i], '').strip('<s>').strip('</s>')
98
+ if '</s>' in temp_resp:
99
+ temp_resp = temp_resp[:temp_resp.rindex('</s>')]
100
+ temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
101
+ temp_resp = temp_resp.split('\nHuman:')[0].strip()
102
+ temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
103
+ temp_resp = temp_resp.split('\nAssistant:')[0].strip()
104
+ temp_resp = temp_resp.split('\n\n\n')[0].strip()
105
+ full_responses_clean.append(full_prompts[i] + ' ' + temp_resp)
106
+ return full_responses_clean
107
+
108
+
109
+ ### for evaluation
110
+ print('evaluation........')
111
+ tokenizer.padding_side = "left"
112
+ if exp_type == 'assistant':
113
+ valid_dataset = build_dataset_eval(hhrlhf_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
114
+ instructions = Instructions()
115
+ else:
116
+ valid_dataset = build_dataset_summary_eval(summary_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
117
+ instructions = Instructions_summary()
118
+ print(f"Size of the validation set: {len(valid_dataset)}")
119
+
120
+ valid_batch_size = 1
121
+ valid_dataset = valid_dataset.remove_columns(['prompt', 'query'])
122
+ for key in ['key', 'text', 'response']:
123
+ if key in valid_dataset.column_names:
124
+ valid_dataset = valid_dataset.remove_columns(key)
125
+
126
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
127
+ valid_data_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, drop_last=True, collate_fn=data_collator)
128
+ accelerator = Accelerator()
129
+ model, valid_data_loader = accelerator.prepare(model, valid_data_loader)
130
+
131
+ full_response_tensors = []
132
+ full_prompts = []
133
+
134
+ pbar = tqdm(total=len(valid_dataset) // valid_batch_size // accelerator.num_processes)
135
+ with torch.no_grad():
136
+ for i, batch in enumerate(valid_data_loader):
137
+ response_tensors = accelerator.unwrap_model(model).generate(batch['input_ids'], attention_mask=batch['attention_mask'], **generation_kwargs)
138
+ full_response_tensors.extend(response_tensors)
139
+ full_prompts.extend(batch['input_ids'])
140
+ pbar.update(1)
141
+
142
+ full_prompts = tokenizer.batch_decode(full_prompts)
143
+ full_responses = tokenizer.batch_decode(full_response_tensors)
144
+ full_responses = get_clean_data(full_responses, full_prompts)
145
+ # Compute score
146
+ queries_responses = [
147
+ (instructions.get_input(text), instructions.get_response(text))
148
+ for text in full_responses
149
+ ]
150
+
151
+
152
+ if hasattr(instructions, 'get_post'):
153
+ rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
154
+ else:
155
+ rewards_list = reward_models.get_reward_model_scores(queries_responses)
156
+
157
+ all_rewards = []
158
+ for i in range(reward_models.num_rewards):
159
+ all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
160
+ ### merge data
161
+ ### errors here may be caused by an old version of transformers/accelerate/peft
162
+ all_full_prompts = accelerator.gather_for_metrics(full_prompts)
163
+ all_full_responses = accelerator.gather_for_metrics(full_responses)
164
+
165
+
166
+ if process_id == 0:
167
+ evaluation_result = {
168
+ 'prompt':all_full_prompts,
169
+ 'response': all_full_responses,
170
+ }
171
+ for i in range(reward_models.num_rewards):
172
+ evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
173
+ print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
174
+
175
+ dataframe = pd.DataFrame(evaluation_result)
176
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data.csv'))
177
+
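To make the cleaning step above concrete, a toy trace of `get_clean_data` on one decoded sample (assumes the function defined in this script is in scope; the strings are purely illustrative):

prompts = ["[PAD] \n\nHuman: How do I bake bread?\n\nAssistant:"]
generations = [
    "<s>\n\nHuman: How do I bake bread?\n\nAssistant:"
    " Mix flour, water and yeast, then let it rise.\n\nHuman: Thanks!</s>"
]
cleaned = get_clean_data(generations, prompts)
# Padding/special tokens are stripped, the trailing "\n\nHuman:" turn is cut off,
# and the cleaned reply is re-attached to its prompt, giving:
# "\n\nHuman: How do I bake bread?\n\nAssistant: Mix flour, water and yeast, then let it rise."
print(cleaned[0])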
ppo/eval_rewarded_soups.py ADDED
@@ -0,0 +1,220 @@
1
+ import os
2
+ import gc
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ from accelerate import Accelerator
6
+ import torch
7
+ from datasets import load_dataset
8
+ from tqdm import tqdm
9
+ from transformers import HfArgumentParser
10
+ from transformers import AutoModelForCausalLM, DataCollatorWithPadding
11
+ from torch.utils.data import DataLoader
12
+ from trl import set_seed
13
+ import numpy as np
14
+ import pandas as pd
15
+ from utils import get_clean_data, load_main_tokenizer, save_configs, print_trainable_parameters, \
16
+ merge_weights_with_preference, Instructions, Instructions_summary, build_dataset_eval, build_dataset_summary_eval
17
+ from multi_reward_models import RewardModels
18
+ tqdm.pandas()
19
+
20
+
21
+ # define paths for two datasets
22
+ hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
23
+ summary_dataset_path = 'openai/summarize_from_feedback'
24
+ tokenizer_path = 'meta-llama/Llama-2-7b-hf'
25
+
26
+ @dataclass
27
+ class ScriptArguments:
28
+ save_directory: Optional[str] = field(default='./logs_rewardedsoups_summary_eval')
29
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "minibatch size for eval"})
30
+ wandb_name: Optional[str] = field(default='eval_pposoups_klreg0.2_harmless_helpful', metadata={"help": "Name for this experiment"})
31
+ reward_names:Optional[str] = field(default='harmless,helpful')
32
+ base_model_path1: Optional[str]=field(default='./ppo_llamma2_klreg0.2_harmless/batch_200')
33
+ base_model_path2: Optional[str]=field(default='./ppo_llamma2_klreg0.2_helpful/batch_200')
34
+ base_model_path3: Optional[str]=field(default='')
35
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
36
+
37
+ parser = HfArgumentParser(ScriptArguments)
38
+ script_args = parser.parse_args_into_dataclasses()[0]
39
+ exp_type = script_args.exp_type
40
+ base_model_name_1 = script_args.base_model_path1
41
+ base_model_name_2 = script_args.base_model_path2
42
+ base_model_name_3 = script_args.base_model_path3
43
+ tokenier_name = tokenizer_path
44
+
45
+ reward_names = [x.strip() for x in script_args.reward_names.split(',')]
46
+ reward_path_tokenizer_dict = {
47
+ 'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
48
+ 'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
49
+ 'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
50
+ 'summary': ['Tristan/gpt2_reward_summarization'],
51
+ 'faithful':['CogComp/bart-faithful-summary-detector'],
52
+ 'humor': ['mohameddhiab/humor-no-humor'],
53
+ }
54
+ reward_model_path_list = []
55
+ rm_tokenizer_path_list = []
56
+ for name in reward_names:
57
+ if name not in reward_path_tokenizer_dict.keys():
58
+ raise NotImplementedError
59
+ reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
60
+ rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
61
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
62
+ save_info = {
63
+ 'base_model_name_1': base_model_name_1,
64
+ 'base_model_name_2': base_model_name_2,
65
+ 'base_model_name_3': base_model_name_3,
66
+ 'reward_peft_path1': reward_model_path_list[0],
67
+ 'reward_peft_path2': reward_model_path_list[1],
68
+ 'tokenier_name': tokenier_name
69
+ }
70
+ for i in range(len(reward_model_path_list)):
71
+ save_info['reward_peft_path{}'.format(i+1)] = reward_model_path_list[i]
72
+ save_configs(save_info, os.path.join(script_args.save_directory, script_args.wandb_name))
73
+
74
+
75
+ accelerator = Accelerator()
76
+ process_id = Accelerator().local_process_index
77
+ gpu_id = process_id
78
+ print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
79
+ reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id)
80
+
81
+
82
+ set_seed(8888)
83
+ current_device = Accelerator().local_process_index
84
+ print(current_device)
85
+
86
+
87
+ tokenizer = load_main_tokenizer(tokenier_name)
88
+ if exp_type == 'assistant':
89
+ valid_dataset = build_dataset_eval(hhrlhf_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
90
+ instructions = Instructions()
91
+ else:
92
+ valid_dataset = build_dataset_summary_eval(summary_dataset_path, tokenizer, reward_models.rm_tokenizers, split='test')
93
+ instructions = Instructions_summary()
94
+ print(f"Size of the validation set: {len(valid_dataset)}")
95
+
96
+
97
+ valid_dataset = valid_dataset.remove_columns(['prompt', 'query'])
98
+ for key in ['key', 'text', 'response']:
99
+ if key in valid_dataset.column_names:
100
+ valid_dataset = valid_dataset.remove_columns(key)
101
+
102
+ def evaluate_model(temp_save_path, tokenizer, valid_dataset):
103
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
104
+ valid_data_loader = DataLoader(valid_dataset, batch_size=script_args.mini_batch_size, drop_last=True, collate_fn=data_collator)
105
+ ### load_merged_model
106
+ model = AutoModelForCausalLM.from_pretrained(
107
+ temp_save_path,
108
+ torch_dtype=torch.bfloat16, # for merging model
109
+ device_map=gpu_id
110
+ )
111
+ model.resize_token_embeddings(len(tokenizer))
112
+
113
+ accelerator = Accelerator()
114
+ model, valid_data_loader = accelerator.prepare(model, valid_data_loader)
115
+
116
+ generation_kwargs = {
117
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
118
+ "min_length": -1,
119
+ "top_k": 0.0,
120
+ "top_p": 0.9,
121
+ "do_sample": True,
122
+ }
123
+
124
+ full_responses = []
125
+ full_prompts = []
126
+ pbar = tqdm(total=len(valid_dataset) // script_args.mini_batch_size // accelerator.num_processes)
127
+ with torch.no_grad():
128
+ for i, batch in enumerate(valid_data_loader):
129
+ response_tensors = accelerator.unwrap_model(model).generate(batch["input_ids"], **generation_kwargs) #length_sampler=output_length_sampler,
130
+ full_responses.extend(response_tensors)
131
+ full_prompts.extend(batch['input_ids'])
132
+ pbar.update(1)
133
+
134
+ full_responses = tokenizer.batch_decode(full_responses)
135
+ full_prompts = tokenizer.batch_decode(full_prompts)
136
+ # clean data
137
+ full_prompts, full_responses = get_clean_data(full_responses, full_prompts)
138
+
139
+ queries_responses = [
140
+ (instructions.get_input(text), instructions.get_response(text))
141
+ for text in full_responses
142
+ ]
143
+ if hasattr(instructions, 'get_post'):
144
+ rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
145
+ else:
146
+ rewards_list = reward_models.get_reward_model_scores(queries_responses)
147
+
148
+ ### errors here may be caused by an old version of transformers/accelerate/peft
149
+ all_rewards = []
150
+ for i in range(reward_models.num_rewards):
151
+ all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
152
+ all_full_prompts = accelerator.gather_for_metrics(full_prompts)
153
+ all_full_responses = accelerator.gather_for_metrics(full_responses)
154
+ return all_rewards, all_full_prompts, all_full_responses
155
+
156
+
157
+ print("Evaluating........")
158
+ tokenizer.padding_side = "left"
159
+ ## preference list
160
+ if reward_models.num_rewards == 3:
161
+ preferences = np.array([
162
+ [0.0, 0.0, 1.0],
163
+ [0.0, 1.0, 0.0],
164
+ [0.1, 0.1, 0.8],
165
+ [0.1, 0.8, 0.1],
166
+ [0.2, 0.2, 0.6],
167
+ [0.2, 0.4, 0.4],
168
+ [0.2, 0.6, 0.2],
169
+ [0.33, 0.33, 0.33],
170
+ [0.4, 0.4, 0.2],
171
+ [0.4, 0.2, 0.4],
172
+ [0.6, 0.2, 0.2],
173
+ [0.8, 0.1, 0.1],
174
+ [1.0, 0.0, 0.0],
175
+ ])
176
+ elif reward_models.num_rewards == 2:
177
+ M = 10
178
+ preferences = np.zeros((M+1, 2))
179
+ preferences[:, 0] = np.arange(0,1+ 1/M,1 / M)
180
+ preferences[:, 1] = 1 - preferences[:, 0]
181
+ preferences = np.round(preferences, 1)
182
+ else:
183
+ raise NotImplementedError
184
+
185
+ for k in range(0, len(preferences)):
186
+ preference = preferences[k]
187
+ temp_save_path = './temp_models/{}'.format(script_args.wandb_name)
188
+ if process_id == 0:
189
+ if len(preference) == 3:
190
+ base_model_list = [base_model_name_1, base_model_name_2, base_model_name_3]
191
+ else:
192
+ base_model_list = [base_model_name_1, base_model_name_2]
193
+ merge_weights_with_preference(base_model_list, preference, temp_save_path)
194
+
195
+ accelerator.wait_for_everyone()
196
+ gc.collect()
197
+ torch.cuda.empty_cache()
198
+ print(k, preference)
199
+ all_rewards, all_full_prompts, all_full_responses = evaluate_model(temp_save_path, tokenizer, valid_dataset)
200
+ gc.collect()
201
+ torch.cuda.empty_cache()
202
+
203
+ if process_id == 0:
204
+ evaluation_result = {
205
+ 'prompt': all_full_prompts,
206
+ 'response': all_full_responses,
207
+ }
208
+ for i in range(reward_models.num_rewards):
209
+ evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
210
+ print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
211
+
212
+ dataframe = pd.DataFrame(evaluation_result)
213
+ if len(preference) == 2:
214
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}.csv'.format(preference[0], preference[1])))
215
+ else:
216
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}_{}.csv'.format(preference[0], preference[1], preference[2])))
217
+
218
+
219
+
220
+
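`merge_weights_with_preference` is imported from `utils` and not shown in this diff; conceptually, rewarded soups interpolates the independently trained policies' parameters with the preference vector before evaluation. A hypothetical sketch of that idea, not the repository's actual implementation (which may also handle tokenizer resizing and sharded checkpoints):

import torch
from transformers import AutoModelForCausalLM

def merge_with_preference_sketch(model_paths, preference, save_path):
    # Accumulate a preference-weighted average of the policies' weights.
    merged = None
    for path, w in zip(model_paths, preference):
        state = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).state_dict()
        if merged is None:
            merged = {k: w * v.float() for k, v in state.items()}
        else:
            for k, v in state.items():
                merged[k] += w * v.float()
    # Reuse the first checkpoint as a container for the interpolated weights.
    base = AutoModelForCausalLM.from_pretrained(model_paths[0], torch_dtype=torch.bfloat16)
    base.load_state_dict({k: v.to(torch.bfloat16) for k, v in merged.items()})
    base.save_pretrained(save_path)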
ppo/morlhf-llama3.py ADDED
@@ -0,0 +1,293 @@
1
+ import os
2
+ import time
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ from accelerate import Accelerator
6
+ import torch
7
+ from tqdm import tqdm
8
+ from transformers import AutoTokenizer, HfArgumentParser
9
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
10
+ import numpy as np
11
+ import pandas as pd
12
+ from utils import print_trainable_parameters, Instructions, Instructions_summary, \
13
+ load_llama3_tokenizer, build_dataset_llama3
14
+ from multi_reward_models import RewardModels
15
+ tqdm.pandas()
16
+ from peft import LoraConfig
17
+ import matplotlib.pyplot as plt
18
+
19
+ # define paths for two datasets
20
+ hhrlhf_dataset_path = '../datasets/hh-rlhf'
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
25
+ disable_wandb: Optional[bool] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
26
+ save_directory: Optional[str] = field(default='./logs_morlhf/')
27
+ epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epochs"})
28
+ learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
29
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
30
+ batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
31
+ gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
32
+ early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
33
+ target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"})
34
+ init_kl_coef: Optional[float] = field(default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
35
+ max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
36
+ load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
37
+ preference: Optional[str] = field(default=None)
38
+ wandb_name: Optional[str] = field(default='morlhf_llamma2_klreg0.2', metadata={"help": "Name for this experiment"})
39
+ base_model_name: Optional[str] = field(default='./merged_sft_summary', metadata={'help':"the path to the sft model; need to merge if using lora"})
40
+ reward_names:Optional[str] = field(default='harmless,helpful,humor')
41
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
42
+
43
+ parser = HfArgumentParser(ScriptArguments)
44
+ script_args = parser.parse_args_into_dataclasses()[0]
45
+ exp_type = script_args.exp_type
46
+
47
+ tokenier_name = script_args.base_model_name
48
+ base_model_name = script_args.base_model_name
49
+ print('base model: ', base_model_name)
50
+
51
+ if script_args.disable_wandb: # if you don't need the wandb log
52
+ os.environ['WANDB_DISABLED'] = 'true'
53
+
54
+ reward_names = [x.strip() for x in script_args.reward_names.split(',')]
55
+ num_rewards = len(reward_names)
56
+ print('number of rewards: {}'.format(num_rewards))
57
+ #######################################################################
58
+ assert num_rewards == 3
59
+ if script_args.preference is None:
60
+ # split the preference evenly across the rewards
61
+ preference = [round(1 / num_rewards, 2) for _ in range(num_rewards)]
62
+ else:
63
+ preference = [float(x) for x in script_args.preference.split(",")]
64
+ print('preference: {}'.format(preference))
65
+ #######################################################################
66
+ reward_path_tokenizer_dict = {
67
+ 'harmless': ['../reward_models/gpt2-large-harmless-reward_model'],
68
+ 'helpful': ['../reward_models/gpt2-large-helpful-reward_model'],
69
+ 'humor': ['../reward_models/humor-no-humor'],
70
+ }
71
+ reward_model_path_list = []
72
+ rm_tokenizer_path_list = []
73
+ for name in reward_names:
74
+ if name not in reward_path_tokenizer_dict.keys():
75
+ raise NotImplementedError
76
+ reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
77
+ rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
78
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
79
+
80
+
81
+ config = PPOConfig(
82
+ model_name=base_model_name,
83
+ learning_rate=script_args.learning_rate,
84
+ # log_with=script_args.log_with,
85
+ mini_batch_size=script_args.mini_batch_size,
86
+ batch_size=script_args.batch_size,
87
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
88
+ early_stopping=script_args.early_stopping,
89
+ target=script_args.target,
90
+ max_grad_norm=script_args.max_grad_norm,
91
+ optimize_cuda_cache=True,
92
+ init_kl_coef=script_args.init_kl_coef,
93
+ # tracker_project_name='morlhf',
94
+ # tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
95
+ remove_unused_columns=False # keep the 'reward_query' column from being dropped
96
+ )
97
+
98
+ accelerator = Accelerator()
99
+ process_id = Accelerator().local_process_index
100
+ gpu_id = process_id
101
+ print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
102
+
103
+
104
+ ############## load reward models
105
+ reward_model = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id)
106
+ rm_tokenizer = AutoTokenizer.from_pretrained(rm_tokenizer_path_list[0])
107
+
108
+ def collator(data):
109
+ return dict((key, [d[key] for d in data]) for key in data[0])
110
+
111
+ set_seed(8888)
112
+ current_device = Accelerator().local_process_index
113
+ print(current_device)
114
+
115
+ lora_config = LoraConfig(
116
+ r=64,
117
+ lora_alpha=128,
118
+ lora_dropout=0.05,
119
+ bias="none",
120
+ task_type="CAUSAL_LM",
121
+ )
122
+
123
+
124
+ tokenizer = load_llama3_tokenizer(tokenier_name)
125
+ if exp_type == 'assistant':
126
+ dataset = build_dataset_llama3(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
127
+ # instructions = Instructions()
128
+ else:
129
+ raise NotImplementedError
130
+ # dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
131
+ # instructions = Instructions_summary()
132
+ train_dataset = dataset.shuffle()
133
+ print(f"Size of the train set: {len(train_dataset)}")
134
+
135
+
136
+ if script_args.load_in_8bit:
137
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
138
+ base_model_name,
139
+ load_in_8bit=True,
140
+ peft_config=lora_config,
141
+ device_map=gpu_id,
142
+ )
143
+ else:
144
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
145
+ base_model_name,
146
+ torch_dtype=torch.bfloat16,
147
+ peft_config=lora_config,
148
+ device_map=gpu_id,
149
+ )
150
+
151
+ print_trainable_parameters(model)
152
+ model.pretrained_model.resize_token_embeddings(len(tokenizer))
153
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
154
+
155
+ ppo_trainer = PPOTrainer(
156
+ config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
157
+ )
158
+
159
+ # generate only a single assistant turn
160
+ terminators = [
161
+ tokenizer.eos_token_id,
162
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
163
+ ]
164
+ generation_kwargs = {
165
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
166
+ "min_length": -1,
167
+ "top_k": 0.0,
168
+ "top_p": 1.0,
169
+ "do_sample": True,
170
+ "temperature": 0.7,
171
+ "pad_token_id": tokenizer.eos_token_id,
172
+ "begin_suppress_tokens": [tokenizer.eos_token_id] ,
173
+ 'eos_token_id': terminators
174
+ }
175
+
176
+
177
+ print("Training........")
178
+ model.gradient_checkpointing_disable()
179
+ model.pretrained_model.config.use_cache = True
180
+ epochs = script_args.epochs
181
+ mean_scores = []
182
+ std_scores = []
183
+ save_data = {
184
+ 'kl_mean': [],
185
+ 'reward_mean': [],
186
+ 'reward_std': [],
187
+ 'text_sample':[],
188
+ 'batch_time':[],
189
+ 'total_time':[],
190
+ }
191
+ t_start = time.time()
192
+ for epoch in range(epochs):
193
+ pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
194
+ for i, batch in enumerate(ppo_trainer.dataloader):
195
+ # print(batch.keys())
196
+ t_epoch_start = time.time()
197
+ print('epoch {}, batch {}'.format(epoch, i))
198
+ query_tensors = batch["input_ids"]
199
+
200
+ model.gradient_checkpointing_disable()
201
+ model.pretrained_model.config.use_cache = True
202
+
203
+ with torch.no_grad():
204
+ response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
205
+
206
+ # per the ppo_trainer implementation, pad tokens have already been stripped from the responses
207
+ full_responses_clean = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) # with skip_special_tokens=True there is no need to strip them manually
208
+
209
+ # full_responses = tokenizer.batch_decode(response_tensors)
210
+ # full_responses_clean = []
211
+ # for _, response in enumerate(full_responses):
212
+ # response = response.strip('[PAD] ')
213
+ # response = response.strip('<unk>')
214
+ # temp_resp = response.strip('<s>').strip('</s>')
215
+ # temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
216
+ # temp_resp = temp_resp.split('\nHuman:')[0].strip()
217
+ # temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
218
+ # temp_resp = temp_resp.split('\nAssistant:')[0].strip()
219
+ # temp_resp = temp_resp.split('\n\n\n')[0].strip()
220
+ # temp_resp = temp_resp.split('###')[0].strip()
221
+ # full_responses_clean.append(temp_resp)
222
+
223
+ clean_texts = full_responses_clean
224
+ clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
225
+
226
+ # map the cleaned responses back to (truncated) response tensors
227
+ lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
228
+ response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
229
+ batch['response'] = clean_texts
230
+
231
+ # Compute score
232
+ # texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
233
+ # # split the prompt and the response out of the merged text
234
+ # queries_responses = [
235
+ # (instructions.get_input(text), instructions.get_response(text))
236
+ # for text in texts_merge
237
+ # ]
238
+ # if hasattr(instructions, 'get_post'):
239
+ # rewards_list = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)
240
+ # else:
241
+ # rewards_list = reward_model.get_reward_model_scores(queries_responses)
242
+
243
+ qa_list = [(q, r) for q, r in zip(batch['reward_query'], batch['response'])]
244
+ rewards_list = reward_model.get_reward_model_scores(qa_list)
245
+ rewards = []
246
+ for j in range(len(qa_list)):
247
+ rewards.append(round(sum([preference[k] * rewards_list[k][j] for k in range(num_rewards)]), 2))
248
+ rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
249
+ print("iter {}, batch {}, mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
250
+
251
+ model.gradient_checkpointing_enable()
252
+ model.pretrained_model.config.use_cache = False
253
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
254
+ policy_kl = [stats["objective/kl"]]
255
+ ppo_trainer.log_stats(stats, batch, rewards)
256
+
257
+ all_rewards = accelerator.gather_for_metrics(rewards)
258
+ all_policy_kl = accelerator.gather_for_metrics(policy_kl)
259
+ if process_id == 0:
260
+ mean_scores.append(torch.mean(torch.tensor(all_rewards)).item())
261
+ std_scores.append(torch.std(torch.tensor(all_rewards)).item())
262
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
263
+ plt.plot(mean_scores)
264
+ plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores)- np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
265
+ plt.savefig(save_path)
266
+ t_epoch_end = time.time()
267
+ save_data['batch_time'].append(t_epoch_end - t_epoch_start)
268
+ save_data['total_time'].append(t_epoch_end - t_start)
269
+ save_data['kl_mean'].append(np.mean(all_policy_kl))
270
+ save_data['reward_mean'] = mean_scores
271
+ save_data['reward_std'] = std_scores
272
+ save_data['text_sample'].append(batch['query'][0] + batch['response'][0])
273
+ # save_data['text_sample'].append(texts_merge[0])
274
+ dataframe = pd.DataFrame(save_data)
275
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
276
+ print("iter {}, batch {}: log finish".format(epoch, i))
277
+
278
+ # wait for the main process
279
+ accelerator.wait_for_everyone()
280
+ pbar.update(1)
281
+
282
+ # save model
283
+ if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
284
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}-batch_{}'.format(epoch, i))
285
+ ppo_trainer.save_pretrained(save_path)
286
+ print("iter {}, batch {}: model saved".format(epoch, i))
287
+
288
+ # save model
289
+ if ppo_trainer.accelerator.is_main_process:
290
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}'.format(epoch))
291
+ ppo_trainer.save_pretrained(save_path)
292
+ print("iter {}, batch {}: model saved".format(epoch, i))
293
+
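The only genuinely multi-objective piece of the loop above is the scalarization of the per-reward-model scores with the preference vector; a toy example of that step with made-up numbers:

# Illustrative numbers: three reward models scored over a batch of two responses.
preference = [0.33, 0.33, 0.34]
rewards_list = [
    [1.2, -0.5],   # harmless scores
    [0.8,  1.1],   # helpful scores
    [0.1,  0.4],   # humor scores
]
num_rewards = len(rewards_list)
batch_size = len(rewards_list[0])
scalarized = [
    round(sum(preference[k] * rewards_list[k][j] for k in range(num_rewards)), 2)
    for j in range(batch_size)
]
print(scalarized)  # [0.69, 0.33] -- one PPO reward per sampled response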
ppo/morlhf.py ADDED
@@ -0,0 +1,277 @@
1
+ import os
2
+ import time
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ from accelerate import Accelerator
6
+ import torch
7
+ from tqdm import tqdm
8
+ from transformers import AutoTokenizer, HfArgumentParser
9
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
10
+ import numpy as np
11
+ import pandas as pd
12
+ from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
13
+ build_dataset, build_dataset_summary
14
+ from multi_reward_models import RewardModels
15
+ tqdm.pandas()
16
+ from peft import LoraConfig
17
+ import matplotlib.pyplot as plt
18
+
19
+ # define paths for two datasets
20
+ hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
21
+ summary_dataset_path = 'openai/summarize_from_feedback'
22
+
23
+ @dataclass
24
+ class ScriptArguments:
25
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
26
+ disable_wandb: Optional[bool] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
27
+ save_directory: Optional[str] = field(default='./logs_morlhf/')
28
+ epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epochs"})
29
+ learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
30
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
31
+ batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
32
+ gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
33
+ early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
34
+ target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"})
35
+ init_kl_coef: Optional[float] = field(default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
36
+ max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
37
+ load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "loading model in 8 bit or bfloat16"})
38
+ preference: Optional[float] = field(default=0.5, metadata={"help": "the weight for reward 1"})
39
+ wandb_name: Optional[str] = field(default='morlhf_llamma2_klreg0.2', metadata={"help": "Name for this experiment"})
40
+ base_model_name: Optional[str] = field(default='./merged_sft_summary', metadata={'help':"the path to the sft model; need to merge if using lora"})
41
+ reward_names:Optional[str] = field(default='harmless,helpful,humor')
42
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
43
+
44
+ parser = HfArgumentParser(ScriptArguments)
45
+ script_args = parser.parse_args_into_dataclasses()[0]
46
+ exp_type = script_args.exp_type
47
+ preference = [round(script_args.preference, 1), round(1 - script_args.preference, 1)]
48
+ script_args.wandb_name = script_args.wandb_name + '_pref{}_{}'.format(preference[0], preference[1])
49
+
50
+ tokenier_name = script_args.base_model_name
51
+ base_model_name = script_args.base_model_name
52
+ print('base model: ', base_model_name)
53
+
54
+ if script_args.disable_wandb: # if you don't need the wandb log
55
+ os.environ['WANDB_DISABLED'] = 'true'
56
+
57
+ reward_names = [x.strip() for x in script_args.reward_names.split(',')]
58
+ num_rewards = len(reward_names)
59
+ print('number of rewards: {}'.format(num_rewards))
60
+ #######################################################################
61
+ if num_rewards == 3:
62
+ preference = [round(1 / num_rewards, 2) for _ in range(num_rewards)]
63
+ print('preference: {}'.format(preference))
64
+ #######################################################################
65
+ reward_path_tokenizer_dict = {
66
+ 'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
67
+ 'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
68
+ 'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
69
+ 'summary': ['Tristan/gpt2_reward_summarization'],
70
+ 'faithful':['CogComp/bart-faithful-summary-detector'],
71
+ 'humor': ['mohameddhiab/humor-no-humor'],
72
+ }
73
+ reward_model_path_list = []
74
+ rm_tokenizer_path_list = []
75
+ for name in reward_names:
76
+ if name not in reward_path_tokenizer_dict.keys():
77
+ raise NotImplementedError
78
+ reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
79
+ rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
80
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
81
+
82
+
83
+ config = PPOConfig(
84
+ model_name=base_model_name,
85
+ learning_rate=script_args.learning_rate,
86
+ log_with=script_args.log_with,
87
+ mini_batch_size=script_args.mini_batch_size,
88
+ batch_size=script_args.batch_size,
89
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
90
+ early_stopping=script_args.early_stopping,
91
+ target=script_args.target,
92
+ max_grad_norm=script_args.max_grad_norm,
93
+ optimize_cuda_cache=True,
94
+ init_kl_coef=script_args.init_kl_coef,
95
+ tracker_project_name='morlhf',
96
+ tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
97
+ )
98
+
99
+ accelerator = Accelerator()
100
+ process_id = Accelerator().local_process_index
101
+ gpu_id = process_id
102
+ print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
103
+
104
+
105
+ ############## load reward models
106
+ reward_model = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id)
107
+ rm_tokenizer = AutoTokenizer.from_pretrained(rm_tokenizer_path_list[0])
108
+
109
+ def collator(data):
110
+ return dict((key, [d[key] for d in data]) for key in data[0])
111
+
112
+ set_seed(8888)
113
+ current_device = Accelerator().local_process_index
114
+ print(current_device)
115
+
116
+ lora_config = LoraConfig(
117
+ r=64,
118
+ lora_alpha=128,
119
+ lora_dropout=0.05,
120
+ bias="none",
121
+ task_type="CAUSAL_LM",
122
+ )
123
+
124
+
125
+ tokenizer = load_main_tokenizer(tokenier_name)
126
+ if exp_type == 'assistant':
127
+ dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
128
+ instructions = Instructions()
129
+ else:
130
+ dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
131
+ instructions = Instructions_summary()
132
+ train_dataset = dataset.shuffle()
133
+ print(f"Size of the train set: {len(train_dataset)}")
134
+
135
+
136
+ if script_args.load_in_8bit:
137
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
138
+ base_model_name,
139
+ load_in_8bit=True,
140
+ peft_config=lora_config,
141
+ device_map=gpu_id,
142
+ )
143
+ else:
144
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
145
+ base_model_name,
146
+ torch_dtype=torch.bfloat16,
147
+ peft_config=lora_config,
148
+ device_map=gpu_id,
149
+ )
150
+
151
+ print_trainable_parameters(model)
152
+ model.pretrained_model.resize_token_embeddings(len(tokenizer))
153
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
154
+
155
+ ppo_trainer = PPOTrainer(
156
+ config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
157
+ )
158
+
159
+ generation_kwargs = {
160
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
161
+ "min_length": -1,
162
+ "top_k": 0.0,
163
+ "top_p": 1.0,
164
+ "do_sample": True,
165
+ "temperature": 0.7,
166
+ "pad_token_id": tokenizer.eos_token_id,
167
+ "begin_suppress_tokens": [tokenizer.eos_token_id] ,
168
+ }
169
+
170
+
171
+ print("Training........")
172
+ model.gradient_checkpointing_disable()
173
+ model.pretrained_model.config.use_cache = True
174
+ epochs = script_args.epochs
175
+ mean_scores = []
176
+ std_scores = []
177
+ save_data = {
178
+ 'kl_mean': [],
179
+ 'reward_mean': [],
180
+ 'reward_std': [],
181
+ 'text_sample':[],
182
+ 'batch_time':[],
183
+ 'total_time':[],
184
+ }
185
+ t_start = time.time()
186
+ for epoch in range(epochs):
187
+ pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
188
+ for i, batch in enumerate(ppo_trainer.dataloader):
189
+ t_epoch_start = time.time()
190
+ print('epoch {}, batch {}'.format(epoch, i))
191
+ query_tensors = batch["input_ids"]
192
+
193
+ model.gradient_checkpointing_disable()
194
+ model.pretrained_model.config.use_cache = True
195
+
196
+ with torch.no_grad():
197
+ response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
198
+
199
+ full_responses = tokenizer.batch_decode(response_tensors)
200
+ full_responses_clean = []
201
+ for _, response in enumerate(full_responses):
202
+ response = response.strip('[PAD] ')
203
+ response = response.strip('<unk>')
204
+ temp_resp = response.strip('<s>').strip('</s>')
205
+ temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
206
+ temp_resp = temp_resp.split('\nHuman:')[0].strip()
207
+ temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
208
+ temp_resp = temp_resp.split('\nAssistant:')[0].strip()
209
+ temp_resp = temp_resp.split('\n\n\n')[0].strip()
210
+ temp_resp = temp_resp.split('###')[0].strip()
211
+ full_responses_clean.append(temp_resp)
212
+
213
+ clean_texts = full_responses_clean
214
+ clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
215
+
216
+ lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
217
+ response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
218
+ batch['response'] = clean_texts
219
+
220
+ # Compute score
221
+ texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
222
+ queries_responses = [
223
+ (instructions.get_input(text), instructions.get_response(text))
224
+ for text in texts_merge
225
+ ]
226
+ if hasattr(instructions, 'get_post'):
227
+ rewards_list = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)
228
+ else:
229
+ rewards_list = reward_model.get_reward_model_scores(queries_responses)
230
+ rewards = []
231
+ for j in range(len(queries_responses)):
232
+ rewards.append(round(sum([preference[k] * rewards_list[k][j] for k in range(num_rewards)]), 2))
233
+ rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
234
+ print("iter {}, batch {}, mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
235
+
236
+ model.gradient_checkpointing_enable()
237
+ model.pretrained_model.config.use_cache = False
238
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
239
+ policy_kl = [stats["objective/kl"]]
240
+ ppo_trainer.log_stats(stats, batch, rewards)
241
+
242
+ all_rewards = accelerator.gather_for_metrics(rewards)
243
+ all_policy_kl = accelerator.gather_for_metrics(policy_kl)
244
+ if process_id == 0:
245
+ mean_scores.append(torch.mean(torch.tensor(all_rewards)).item())
246
+ std_scores.append(torch.std(torch.tensor(all_rewards)).item())
247
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
248
+ plt.plot(mean_scores)
249
+ plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores)- np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
250
+ plt.savefig(save_path)
251
+ t_epoch_end = time.time()
252
+ save_data['batch_time'].append(t_epoch_end - t_epoch_start)
253
+ save_data['total_time'].append(t_epoch_end - t_start)
254
+ save_data['kl_mean'].append(np.mean(all_policy_kl))
255
+ save_data['reward_mean'] = mean_scores
256
+ save_data['reward_std'] = std_scores
257
+ save_data['text_sample'].append(texts_merge[0])
258
+ dataframe = pd.DataFrame(save_data)
259
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
260
+ print("iter {}, batch {}: log finish".format(epoch, i))
261
+
262
+ # wait for the main process
263
+ accelerator.wait_for_everyone()
264
+ pbar.update(1)
265
+
266
+ # save model
267
+ if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
268
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}-batch_{}'.format(epoch, i))
269
+ ppo_trainer.save_pretrained(save_path)
270
+ print("iter {}, batch {}: model saved".format(epoch, i))
271
+
272
+ # save model
273
+ if ppo_trainer.accelerator.is_main_process:
274
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'batch_{}'.format(i))
275
+ ppo_trainer.save_pretrained(save_path)
276
+ print("iter {}, batch {}: model saved".format(epoch, i))
277
+
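One subtle step in the loop above is that each sampled response tensor is truncated to the length of its re-encoded cleaned text (with a floor of 2 tokens), so PPO only optimizes the tokens that are kept. A standalone sketch of that truncation (the GPT-2 tokenizer and the toy id lists are stand-ins for the policy tokenizer and the sampled tensors):

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")       # example tokenizer only
clean_texts = ["Mix flour, water and yeast.", ""]       # second response cleaned down to nothing
response_tensors = [list(range(30)), list(range(30))]   # pretend sampled token ids

clean_ids = [tokenizer.encode(text) for text in clean_texts]
lengths = [len(ids) for ids in clean_ids]
truncated = [
    response_tensors[j][:np.max([lengths[j], 2])]       # keep at least 2 tokens
    for j in range(len(response_tensors))
]
print([len(t) for t in truncated])                      # e.g. [7, 2]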
ppo/multi_reward_models.py ADDED
@@ -0,0 +1,64 @@
1
+ from transformers import AutoTokenizer
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from utils import load_reward_model, get_rewards
6
+
7
+ class RewardModels():
8
+ def __init__(self, reward_model_path_list, rm_tokenizer_path_list, gpu_id_list, reward_stats_path=None):
9
+ assert len(reward_model_path_list) == len(rm_tokenizer_path_list)
10
+ self.reward_model_path_list = reward_model_path_list
11
+ self.rm_tokenizer_path_list = rm_tokenizer_path_list
12
+ self.num_rewards = len(reward_model_path_list)
13
+ self.reward_stats = np.load(reward_stats_path) if reward_stats_path is not None else None
14
+ self.reward_models = []
15
+ self.rm_tokenizers = []
16
+ if type(gpu_id_list) != list:
17
+ gpu_id_list = [gpu_id_list, gpu_id_list, gpu_id_list]
18
+
19
+ print('Loading reward models .....')
20
+ for i in range(self.num_rewards):
21
+ self.reward_models.append(load_reward_model(self.reward_model_path_list[i], gpu_id_list[i]))
22
+ self.rm_tokenizers.append(AutoTokenizer.from_pretrained(self.rm_tokenizer_path_list[i]))
23
+
24
+
25
+ def get_reward_model_scores(self, queries_responses, summary_fun=None):
26
+ texts_for_rewards = []
27
+ for i in range(self.num_rewards):
28
+ if i >= 1 and self.rm_tokenizer_path_list[i] == self.rm_tokenizer_path_list[i-1]:
29
+ texts_for_rewards.append(texts_for_rewards[-1])
30
+ elif 'faithful' in self.reward_model_path_list[i]:
31
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
32
+ temp_encoded_texts = [self.rm_tokenizers[i](text=r, text_pair=summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
33
+ texts_for_rewards.append(temp_encoded_texts)
34
+ elif 'summary' in self.reward_model_path_list[i] or 'summarization' in self.reward_model_path_list[i]: # reverse prompt and response
35
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
36
+ temp_encoded_texts = [self.rm_tokenizers[i](r + " " + self.rm_tokenizers[i].bos_token + " " + summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
37
+ texts_for_rewards.append(temp_encoded_texts)
38
+ elif 'humor' in self.reward_model_path_list[i]: # use only response
39
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
40
+ temp_encoded_texts = [self.rm_tokenizers[i](r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
41
+ texts_for_rewards.append(temp_encoded_texts)
42
+ else:
43
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
44
+ temp_encoded_texts = [self.rm_tokenizers[i](q, r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
45
+ texts_for_rewards.append(temp_encoded_texts)
46
+
47
+ # normalize reward
48
+ rewards = []
49
+ for i in range(self.num_rewards):
50
+ if self.reward_stats is not None:
51
+ if type(self.reward_stats) == list or len(self.reward_stats) == 2 * self.num_rewards:
52
+ reward_mean_std = (self.reward_stats[2*i], self.reward_stats[2*i+1])
53
+ else:
54
+ reward_mean_std = self.reward_stats[i]
55
+ else:
56
+ reward_mean_std = None
57
+
58
+ if 'humor' in self.reward_model_path_list[i] or 'faithful' in self.reward_model_path_list[i]:
59
+ temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std, sub_position=1)
60
+ else:
61
+ temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std)
62
+ rewards.append(temp_reward)
63
+ return rewards
64
+
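A minimal usage sketch of the class above for the assistant task (the model IDs match the reward dictionaries used elsewhere in this repo; a single GPU id is reused for every reward model):

from multi_reward_models import RewardModels

reward_models = RewardModels(
    reward_model_path_list=['Ray2333/gpt2-large-harmless-reward_model',
                            'Ray2333/gpt2-large-helpful-reward_model'],
    rm_tokenizer_path_list=['Ray2333/gpt2-large-harmless-reward_model',
                            'Ray2333/gpt2-large-helpful-reward_model'],
    gpu_id_list=0,            # one id, broadcast to all reward models
    reward_stats_path=None,   # optional .npy of per-reward mean/std for normalization
)

queries_responses = [
    ("\n\nHuman: How do I bake bread?\n\nAssistant:", "Mix flour, water and yeast."),
]
# Returns one list of scores per reward model, in the order they were given.
harmless_scores, helpful_scores = reward_models.get_reward_model_scores(queries_responses)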
ppo/ppo.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+ from accelerate import Accelerator
5
+ import torch
6
+ from tqdm import tqdm
7
+ from transformers import HfArgumentParser
8
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
9
+ import numpy as np
10
+ import pandas as pd
11
+ from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
12
+ build_dataset, build_dataset_summary
13
+ from multi_reward_models import RewardModels
14
+ tqdm.pandas()
15
+ from peft import LoraConfig
16
+ import matplotlib.pyplot as plt
17
+ from transformers import LlamaForCausalLM
18
+ os.environ["WANDB_MODE"] = "offline"
19
+ # define paths for two datasets
20
+ hhrlhf_dataset_path = '../datasets/hh-rlhf'
21
+ # summary_dataset_path = 'openai/summarize_from_feedback'
22
+
23
+ class BlockWrapper(torch.nn.Module):
24
+ def __init__(self, block, vec=None):
25
+ super().__init__()
26
+ self.multiplier = 1.0
27
+ self.block = block
28
+ if vec is not None:
29
+ self.vec = torch.nn.Parameter(vec)
30
+ else:
31
+ # Zero Init
32
+ self.vec = torch.nn.Parameter(torch.zeros(4096, dtype=torch.bfloat16))
33
+
34
+ def forward(self, *args, **kwargs):
35
+ output = self.block(*args, **kwargs)
36
+ output = (output[0] + (self.multiplier * self.vec),) + output[1:]
37
+ return output
38
+
39
+ def set_multiplier(self, multiplier):
40
+ self.multiplier = multiplier
41
+
42
+ @dataclass
43
+ class ScriptArguments:
44
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
45
+ disable_wandb: Optional[bool] = field(default=True, metadata={'help': 'Whether to disable wandb or not.'})
46
+ save_directory: Optional[str] = field(default='./ckpt/')
47
+ epochs: Optional[int] = field(default=3, metadata={'help': "Number of training epochs"})
48
+ learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
49
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
50
+ batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
51
+ load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
52
+ gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
53
+ early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
54
+ target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"}) # maybe 12
55
+ init_kl_coef: Optional[float] = field(default=0.2,metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
56
+ max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
57
+ wandb_name: Optional[str] = field(default='new', metadata={"help": "Name for this experiment"})
58
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
59
+ base_model_name: Optional[str] = field(default=None, metadata={'help':"the path to the sft model; need to merge if using lora"})
60
+ reward_name: Optional[str] = field(default='helpful')
61
+ layers: Optional[str] = field(default='15', metadata={"help": "the layers at which the steering vector is inserted, e.g. '12,13,14,15'"})
62
+
63
+ parser = HfArgumentParser(ScriptArguments)
64
+ script_args = parser.parse_args_into_dataclasses()[0]
65
+ exp_type = script_args.exp_type
66
+ # Remember to use a merged sft model if using lora
67
+ base_model_name = script_args.base_model_name
68
+ tokenier_name = script_args.base_model_name
69
+ print('base model: ', base_model_name)
70
+
71
+ script_args.layers = [int(layer) for layer in script_args.layers.split(',')]
72
+
73
+ if script_args.disable_wandb: # if you don't need the wandb log
74
+ os.environ['WANDB_DISABLED'] = 'true'
75
+
76
+ reward_name = script_args.reward_name
77
+ # if reward_name == 'summary':
78
+ # reward_peft_path = 'Tristan/gpt2_reward_summarization'
79
+ # elif reward_name == 'faithful':
80
+ # reward_peft_path = 'CogComp/bart-faithful-summary-detector'
81
+ if reward_name == 'helpful':
82
+ reward_peft_path = '../reward_models/gpt2-large-helpful-reward_model'
83
+ elif reward_name == 'harmless':
84
+ reward_peft_path = '../reward_models/gpt2-large-harmless-reward_model'
85
+ # elif reward_name == 'deberta':
86
+ # reward_peft_path = 'OpenAssistant/reward-model-deberta-v3-large-v2'
87
+ elif reward_name == 'humor':
88
+ reward_peft_path = '../reward_models/humor-no-humor'
89
+ else:
90
+ raise NotImplementedError
91
+ rm_tokenizer_path = reward_peft_path
92
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
93
+
94
+
95
+ config = PPOConfig(
96
+ model_name=base_model_name,
97
+ learning_rate=script_args.learning_rate,
98
+ # log_with=script_args.log_with,
99
+ mini_batch_size=script_args.mini_batch_size,
100
+ batch_size=script_args.batch_size,
101
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
102
+ early_stopping=script_args.early_stopping,
103
+ target=script_args.target,
104
+ max_grad_norm=script_args.max_grad_norm,
105
+ optimize_cuda_cache=True,
106
+ init_kl_coef=script_args.init_kl_coef
107
+ # tracker_project_name='ppo',
108
+ # tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
109
+ )
110
+
111
+ accelerator = Accelerator()
112
+ process_id = Accelerator().local_process_index
113
+ gpu_id = process_id
114
+ print('process: {}'.format(process_id))
115
+ reward_model = RewardModels([reward_peft_path], [rm_tokenizer_path], gpu_id)
116
+ rm_tokenizer = reward_model.rm_tokenizers[0]
117
+
118
+
119
+ # set seed before initializing value head for deterministic eval
120
+ set_seed(8888)
121
+ current_device = Accelerator().local_process_index
122
+ print(current_device)
123
+
124
+ # lora_config = LoraConfig(
125
+ # r=64,
126
+ # lora_alpha=128,
127
+ # lora_dropout=0.05,
128
+ # bias="none",
129
+ # task_type="CAUSAL_LM",
130
+ # )
131
+
132
+ tokenizer = load_main_tokenizer(tokenizer_name)
133
+ if exp_type == 'assistant':
134
+ dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
135
+ instructions = Instructions()
136
+ else:
137
+ raise NotImplementedError
138
+ # dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
139
+ # instructions = Instructions_summary()
140
+ train_dataset = dataset.shuffle()
141
+ print(f"Size of the train set: {len(train_dataset)}.")
142
+
143
+ if script_args.load_in_8bit:
144
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
145
+ base_model_name,
146
+ load_in_8bit=True,
147
+ # peft_config=lora_config,
148
+ device_map=gpu_id,
149
+ )
150
+ else:
151
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
152
+ base_model_name,
153
+ torch_dtype=torch.bfloat16,
154
+ # peft_config=lora_config,
155
+ device_map=gpu_id,
156
+ )
157
+
158
+ ###################
159
+ for layer in script_args.layers:
160
+ model.pretrained_model.model.layers[layer] = BlockWrapper(model.pretrained_model.model.layers[layer])
161
+ model.pretrained_model.config.use_cache = False
162
+ for name, param in model.pretrained_model.named_parameters():
163
+ param.requires_grad = False
164
+ for layer in script_args.layers:
165
+ model.pretrained_model.model.layers[layer].vec.requires_grad = True
166
+ ###################
167
+
168
+ print_trainable_parameters(model)
169
+ print('trainable params names:')
170
+ for name, param in model.named_parameters():
171
+ if param.requires_grad:
172
+ print(name)
173
+
174
+ model.pretrained_model.resize_token_embeddings(len(tokenizer))
175
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
176
+ def collator(data):
177
+ return dict((key, [d[key] for d in data]) for key in data[0])
178
+
179
+ ppo_trainer = PPOTrainer(
180
+ config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
181
+ )
182
+
183
+ generation_kwargs = {
184
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
185
+ 'min_length': -1,
186
+ "top_k": 0.0,
187
+ "top_p": 1.0,
188
+ "do_sample": True,
189
+ "temperature": 0.7,
190
+ "pad_token_id": tokenizer.eos_token_id,
191
+ "begin_suppress_tokens": [tokenizer.eos_token_id],
192
+ }
193
+
194
+ print("Training........")
195
+ model.gradient_checkpointing_disable()
196
+ model.pretrained_model.config.use_cache = True
197
+
198
+ epochs = script_args.epochs
199
+ mean_scores = []
200
+ std_scores = []
201
+ save_data = {
202
+ 'kl_mean': [],
203
+ 'kl_std': [],
204
+ 'reward_mean': [],
205
+ 'reward_std': [],
206
+ 'text_sample':[],
207
+ }
208
+ for epoch in range(epochs):
209
+ pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
210
+ for i, batch in enumerate(ppo_trainer.dataloader):
211
+ print('epoch {}, batch {}'.format(epoch, i))
212
+ query_tensors = batch["input_ids"]
213
+
214
+ model.gradient_checkpointing_disable()
215
+ model.pretrained_model.config.use_cache = True
216
+
217
+ with torch.no_grad():
218
+ response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
219
+
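+ # decode the sampled responses, strip special tokens, and cut each one at the next dialogue marker so only the first assistant reply is kept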
220
+ full_responses = tokenizer.batch_decode(response_tensors)
221
+ full_responses_clean = []
222
+ for _, response in enumerate(full_responses):
223
+ response = response.strip('[PAD] ')
224
+ response = response.strip('<unk>')
225
+ temp_resp = response.strip('<s>').strip('</s>')
226
+ temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
227
+ temp_resp = temp_resp.split('\nHuman:')[0].strip()
228
+ temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
229
+ temp_resp = temp_resp.split('\nAssistant:')[0].strip()
230
+ temp_resp = temp_resp.split('###')[0].strip()
231
+ temp_resp = temp_resp.split('\n\n\n')[0].strip()
232
+ full_responses_clean.append(temp_resp)
233
+
234
+ clean_texts = full_responses_clean
235
+ clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
236
+
237
+ lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
238
+ print(lengths)
239
+ response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
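+ # truncate the raw response tensors to roughly the length of the re-encoded cleaned text (at least 2 tokens) so PPO only optimizes the kept portion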
240
+ batch['response'] = clean_texts
241
+
242
+ # Compute score
243
+ texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
244
+ queries_responses = [
245
+ (instructions.get_input(text), instructions.get_response(text))
246
+ for text in texts_merge
247
+ ]
248
+ if hasattr(instructions, 'get_post'):
249
+ rewards = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)[0]
250
+ else:
251
+ rewards = reward_model.get_reward_model_scores(queries_responses)[0]
252
+ rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
253
+ print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
254
+
255
+ model.gradient_checkpointing_enable()
256
+ model.pretrained_model.config.use_cache = False
257
+ ppo_trainer.config.batch_size = len(query_tensors)
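+ # the last dataloader batch can be smaller than batch_size, so keep the trainer config in sync before step()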
258
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
259
+ ppo_trainer.log_stats(stats, batch, rewards)
260
+ policy_kl = [stats["objective/kl"]]
261
+
262
+ all_rewards = accelerator.gather_for_metrics(rewards)
263
+ all_policy_kl = accelerator.gather_for_metrics(policy_kl)
264
+ if ppo_trainer.accelerator.is_main_process:
265
+ mean_scores.append(torch.mean(torch.tensor(rewards)).item())
266
+ std_scores.append(torch.std(torch.tensor(rewards)).item())
267
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
268
+ plt.plot(mean_scores)
269
+ plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores) - np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
270
+ plt.savefig(save_path)
271
+
272
+ save_data['kl_mean'].append(np.mean(all_policy_kl))
273
+ save_data['kl_std'].append(np.std(all_policy_kl))
274
+ save_data['reward_mean'] = mean_scores
275
+ save_data['reward_std'] = std_scores
276
+ save_data['text_sample'].append(texts_merge[0])
277
+ dataframe = pd.DataFrame(save_data)
278
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
279
+ print("iter {}, batch {}: log finish".format(epoch, i))
280
+
281
+ # wait for the main process
282
+ accelerator.wait_for_everyone()
283
+ pbar.update(1)
284
+
285
+ # save model
286
+ if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
287
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
288
+ # ppo_trainer.save_pretrained(save_path)
289
+
290
+ for layer in script_args.layers:
291
+ # steer_vec = ppo_trainer.model.pretrained_model.model.layers[layer].vec.detach().cpu()
292
+ steer_vec = model.pretrained_model.model.layers[layer].vec.detach().cpu()
293
+ print(f'Steer vec of layer {layer} at step {i}: ', steer_vec[:10], steer_vec.dtype)
294
+ torch.save(
295
+ steer_vec,
296
+ f"{save_path}/vec_layer{layer}-step{i}.pt",
297
+ )
298
+ print("iter {}, batch {}: model saved".format(epoch, i))
299
+
300
+ # save model
301
+ if ppo_trainer.accelerator.is_main_process:
302
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
303
+ # ppo_trainer.save_pretrained(save_path)
304
+ for layer in script_args.layers:
305
+ # steer_vec = ppo_trainer.model.pretrained_model.model.layers[layer].vec.detach().cpu()
306
+ steer_vec = model.pretrained_model.model.layers[layer].vec.detach().cpu()
307
+ print(f'Steer vec of layer {layer} at step {i}: ', steer_vec[:10], steer_vec.dtype)
308
+ torch.save(
309
+ steer_vec,
310
+ f"{save_path}/vec_layer{layer}-step{i}.pt",
311
+ )
312
+ print("iter {}, batch {}: model saved".format(epoch, i))
313
+
ppo/ppo_reft.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+ from accelerate import Accelerator
5
+ import torch
6
+ import math
7
+ from tqdm import tqdm
8
+ from transformers import HfArgumentParser
9
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
10
+ import numpy as np
11
+ import pandas as pd
12
+ from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
13
+ build_dataset, build_dataset_summary
14
+ from multi_reward_models import RewardModels
15
+ tqdm.pandas()
16
+ from peft import LoraConfig
17
+ import matplotlib.pyplot as plt
18
+ os.environ["WANDB_MODE"] = "offline"
19
+ # define paths for two datasets
20
+ hhrlhf_dataset_path = '../datasets/hh-rlhf'
21
+ # summary_dataset_path = 'openai/summarize_from_feedback'
22
+
23
+ from transformers.activations import ACT2FN
24
+ class BlockWrapper(torch.nn.Module):
25
+ def __init__(self, block, rank=64):
26
+ super().__init__()
27
+ self.multiplier = 1.0
28
+ self.block = block
29
+ self.learned_source = torch.nn.Linear(4096, rank).to(torch.bfloat16)
30
+ torch.nn.init.kaiming_uniform_(self.learned_source.weight, a=math.sqrt(5))
31
+ self.proj_layer = torch.nn.Linear(rank, 4096, bias=False).to(torch.bfloat16)
32
+ torch.nn.init.zeros_(self.proj_layer.weight)
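+ # zero-init keeps the low-rank update a no-op at the start, so the wrapped block initially matches the original model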
33
+ # self.act_fn = ACT2FN["linear"] # TODO: revise
34
+
35
+ def forward(self, *args, **kwargs):
36
+ output = self.block(*args, **kwargs)
37
+ output = (output[0] + (self.multiplier * self.proj_layer(self.learned_source(output[0]))),) + output[1:]
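+ # ReFT-style intervention: an additive low-rank edit of the block's hidden states, scaled by self.multiplier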
38
+ return output
39
+
40
+ def set_multiplier(self, multiplier):
41
+ self.multiplier = multiplier
42
+
43
+ @dataclass
44
+ class ScriptArguments:
45
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
46
+ disable_wandb: Optional[bool] = field(default=True, metadata={'help': 'Whether to disable wandb or not.'})
47
+ save_directory: Optional[str] = field(default='./ckpt/')
48
+ epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epochs"})
49
+ learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
50
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
51
+ batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
52
+ load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
53
+ gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
54
+ early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
55
+ target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"}) # maybe 12
56
+ init_kl_coef: Optional[float] = field(default=0.8,metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
57
+ max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
58
+ wandb_name: Optional[str] = field(default='new', metadata={"help": "Name for this experiment"})
59
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type: 'assistant' or 'summary'"})
60
+ base_model_name: Optional[str] = field(default=None, metadata={'help':"the path to the sft model; need to merge if using lora"})
61
+ reward_name: Optional[str] = field(default='helpful')
62
+ layers: Optional[str] = field(default='15', metadata={"help": "the layers the steering vector extracted from: 12,13,14,15"})
63
+
64
+ parser = HfArgumentParser(ScriptArguments)
65
+ script_args = parser.parse_args_into_dataclasses()[0]
66
+ exp_type = script_args.exp_type
67
+ # Remember to use a merged sft model if using lora
68
+ base_model_name = script_args.base_model_name
69
+ tokenizer_name = script_args.base_model_name
70
+ print('base model: ', base_model_name)
71
+
72
+ script_args.layers = [int(layer) for layer in script_args.layers.split(',')]
73
+
74
+ if script_args.disable_wandb: # if you don't need the wandb log
75
+ os.environ['WANDB_DISABLED'] = 'true'
76
+
77
+ reward_name = script_args.reward_name
78
+ # if reward_name == 'summary':
79
+ # reward_peft_path = 'Tristan/gpt2_reward_summarization'
80
+ # elif reward_name == 'faithful':
81
+ # reward_peft_path = 'CogComp/bart-faithful-summary-detector'
82
+ if reward_name == 'helpful':
83
+ reward_peft_path = '../reward_models/gpt2-large-helpful-reward_model'
84
+ elif reward_name == 'harmless':
85
+ reward_peft_path = '../reward_models/gpt2-large-harmless-reward_model'
86
+ # elif reward_name == 'deberta':
87
+ # reward_peft_path = 'OpenAssistant/reward-model-deberta-v3-large-v2'
88
+ elif reward_name == 'humor':
89
+ reward_peft_path = '../reward_models/humor-no-humor'
90
+ else:
91
+ raise NotImplementedError
92
+ rm_tokenizer_path = reward_peft_path
93
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
94
+
95
+
96
+ config = PPOConfig(
97
+ model_name=base_model_name,
98
+ learning_rate=script_args.learning_rate,
99
+ # log_with=script_args.log_with,
100
+ mini_batch_size=script_args.mini_batch_size,
101
+ batch_size=script_args.batch_size,
102
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
103
+ early_stopping=script_args.early_stopping,
104
+ target=script_args.target,
105
+ max_grad_norm=script_args.max_grad_norm,
106
+ optimize_cuda_cache=True,
107
+ init_kl_coef=script_args.init_kl_coef
108
+ # tracker_project_name='ppo',
109
+ # tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
110
+ )
111
+
112
+ accelerator = Accelerator()
113
+ process_id = Accelerator().local_process_index
114
+ gpu_id = process_id
115
+ print('process: {}'.format(process_id))
116
+ reward_model = RewardModels([reward_peft_path], [rm_tokenizer_path], gpu_id)
117
+ rm_tokenizer = reward_model.rm_tokenizers[0]
118
+
119
+
120
+ # set seed before initializing value head for deterministic eval
121
+ set_seed(8888)
122
+ current_device = Accelerator().local_process_index
123
+ print(current_device)
124
+
125
+ # lora_config = LoraConfig(
126
+ # r=64,
127
+ # lora_alpha=128,
128
+ # lora_dropout=0.05,
129
+ # bias="none",
130
+ # task_type="CAUSAL_LM",
131
+ # )
132
+
133
+ tokenizer = load_main_tokenizer(tokenizer_name)
134
+ if exp_type == 'assistant':
135
+ dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
136
+ instructions = Instructions()
137
+ else:
138
+ raise NotImplementedError
139
+ # dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
140
+ # instructions = Instructions_summary()
141
+ train_dataset = dataset.shuffle()
142
+ print(f"Size of the train set: {len(train_dataset)}.")
143
+
144
+ if script_args.load_in_8bit:
145
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
146
+ base_model_name,
147
+ load_in_8bit=True,
148
+ # peft_config=lora_config,
149
+ device_map=gpu_id,
150
+ )
151
+ else:
152
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
153
+ base_model_name,
154
+ torch_dtype=torch.bfloat16,
155
+ # peft_config=lora_config,
156
+ device_map=gpu_id,
157
+ )
158
+
159
+ ###################
160
+ for layer in script_args.layers:
161
+ model.pretrained_model.model.layers[layer] = BlockWrapper(model.pretrained_model.model.layers[layer])
162
+ model.pretrained_model.config.use_cache = False
163
+ for name, param in model.pretrained_model.named_parameters():
164
+ param.requires_grad = False
165
+ for layer in script_args.layers:
166
+ for name, param in model.pretrained_model.model.layers[layer].learned_source.named_parameters():
167
+ param.requires_grad = True
168
+ for name, param in model.pretrained_model.model.layers[layer].proj_layer.named_parameters():
169
+ param.requires_grad = True
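+ # only the adapter parameters (learned_source and proj_layer) in the wrapped layers are trainable; the base model stays frozen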
170
+ ###################
171
+
172
+ print_trainable_parameters(model)
173
+ print('trainable params names:')
174
+ for name, param in model.named_parameters():
175
+ if param.requires_grad:
176
+ print(name)
177
+
178
+ model.pretrained_model.resize_token_embeddings(len(tokenizer))
179
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
180
+ def collator(data):
181
+ return dict((key, [d[key] for d in data]) for key in data[0])
182
+
183
+ ppo_trainer = PPOTrainer(
184
+ config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
185
+ )
186
+
187
+ # print(ppo_trainer.model)
188
+ # print(ppo_trainer.model.module.pretrained_model)
189
+ # exit()
190
+
191
+ generation_kwargs = {
192
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
193
+ 'min_length': -1,
194
+ "top_k": 0.0,
195
+ "top_p": 1.0,
196
+ "do_sample": True,
197
+ "temperature": 0.7,
198
+ "pad_token_id": tokenizer.eos_token_id,
199
+ "begin_suppress_tokens": [tokenizer.eos_token_id],
200
+ }
201
+
202
+ print("Training........")
203
+ model.gradient_checkpointing_disable()
204
+ model.pretrained_model.config.use_cache = True
205
+
206
+ epochs = script_args.epochs
207
+ mean_scores = []
208
+ std_scores = []
209
+ save_data = {
210
+ 'kl_mean': [],
211
+ 'kl_std': [],
212
+ 'reward_mean': [],
213
+ 'reward_std': [],
214
+ 'text_sample':[],
215
+ }
216
+ for epoch in range(epochs):
217
+ pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
218
+ for i, batch in enumerate(ppo_trainer.dataloader):
219
+ print('epoch {}, batch {}'.format(epoch, i))
220
+ query_tensors = batch["input_ids"]
221
+
222
+ model.gradient_checkpointing_disable()
223
+ model.pretrained_model.config.use_cache = True
224
+
225
+ with torch.no_grad():
226
+ response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
227
+
228
+ full_responses = tokenizer.batch_decode(response_tensors)
229
+ full_responses_clean = []
230
+ for _, response in enumerate(full_responses):
231
+ response = response.strip('[PAD] ')
232
+ response = response.strip('<unk>')
233
+ temp_resp = response.strip('<s>').strip('</s>')
234
+ temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
235
+ temp_resp = temp_resp.split('\nHuman:')[0].strip()
236
+ temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
237
+ temp_resp = temp_resp.split('\nAssistant:')[0].strip()
238
+ temp_resp = temp_resp.split('###')[0].strip()
239
+ temp_resp = temp_resp.split('\n\n\n')[0].strip()
240
+ full_responses_clean.append(temp_resp)
241
+
242
+ clean_texts = full_responses_clean
243
+ clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
244
+
245
+ lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
246
+ print(lengths)
247
+ response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
248
+ batch['response'] = clean_texts
249
+
250
+ # Compute score
251
+ texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
252
+ queries_responses = [
253
+ (instructions.get_input(text), instructions.get_response(text))
254
+ for text in texts_merge
255
+ ]
256
+ if hasattr(instructions, 'get_post'):
257
+ rewards = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)[0]
258
+ else:
259
+ rewards = reward_model.get_reward_model_scores(queries_responses)[0]
260
+ rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
261
+ print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
262
+
263
+ model.gradient_checkpointing_enable()
264
+ model.pretrained_model.config.use_cache = False
265
+ ppo_trainer.config.batch_size = len(query_tensors)
266
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
267
+ ppo_trainer.log_stats(stats, batch, rewards)
268
+ policy_kl = [stats["objective/kl"]]
269
+
270
+ all_rewards = accelerator.gather_for_metrics(rewards)
271
+ all_policy_kl = accelerator.gather_for_metrics(policy_kl)
272
+ if ppo_trainer.accelerator.is_main_process:
273
+ mean_scores.append(torch.mean(torch.tensor(rewards)).item())
274
+ std_scores.append(torch.std(torch.tensor(rewards)).item())
275
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
276
+ plt.plot(mean_scores)
277
+ plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores) - np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
278
+ plt.savefig(save_path)
279
+
280
+ save_data['kl_mean'].append(np.mean(all_policy_kl))
281
+ save_data['kl_std'].append(np.std(all_policy_kl))
282
+ save_data['reward_mean'] = mean_scores
283
+ save_data['reward_std'] = std_scores
284
+ save_data['text_sample'].append(texts_merge[0])
285
+ dataframe = pd.DataFrame(save_data)
286
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
287
+ print("iter {}, batch {}: log finish".format(epoch, i))
288
+
289
+ # wait for the main process
290
+ accelerator.wait_for_everyone()
291
+ pbar.update(1)
292
+
293
+ # save model
294
+ if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
295
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
296
+ # ppo_trainer.save_pretrained(save_path)
297
+
298
+ for layer in script_args.layers:
299
+ print(f'learned_source layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source.weight[:10])
300
+ torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source, f"{save_path}/learned_source-{layer}-step{i}.pt")
301
+ torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_layer, f"{save_path}/proj_layer-{layer}-step{i}.pt")
302
+ print("iter {}, batch {}: model saved".format(epoch, i))
303
+
304
+ # save model
305
+ if ppo_trainer.accelerator.is_main_process:
306
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
307
+ # ppo_trainer.save_pretrained(save_path)
308
+ for layer in script_args.layers:
309
+ print(f'learned_source layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source.weight[:10])
310
+ torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].learned_source, f"{save_path}/learned_source-{layer}-step{i}.pt")
311
+ torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_layer, f"{save_path}/proj_layer-{layer}-step{i}.pt")
312
+ print("iter {}, batch {}: model saved".format(epoch, i))
313
+
ppo/ppo_reft_fine_grained.py ADDED
@@ -0,0 +1,336 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+ from accelerate import Accelerator
5
+ import torch
6
+ import math
7
+ from tqdm import tqdm
8
+ from transformers import HfArgumentParser
9
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
10
+ import numpy as np
11
+ import pandas as pd
12
+ from utils import print_trainable_parameters, load_main_tokenizer, Instructions, Instructions_summary, \
13
+ build_dataset, build_dataset_summary
14
+ from multi_reward_models import RewardModels
15
+ tqdm.pandas()
16
+ from peft import LoraConfig
17
+ import matplotlib.pyplot as plt
18
+ os.environ["WANDB_MODE"] = "offline"
19
+ # define paths for two datasets
20
+ hhrlhf_dataset_path = '../datasets/hh-rlhf'
21
+ # summary_dataset_path = 'openai/summarize_from_feedback'
22
+
23
+ from transformers.activations import ACT2FN
24
+ class BlockWrapper(torch.nn.Module):
25
+ def __init__(self, block, dim=4096, rank=64):
26
+ super().__init__()
27
+ self.multiplier = 1.0
28
+ self.block = block
29
+
30
+ self.proj_A = torch.nn.Linear(dim, rank).to(torch.bfloat16)
31
+ torch.nn.init.kaiming_uniform_(self.proj_A.weight, a=math.sqrt(5))
32
+ for name, param in self.proj_A.named_parameters():
33
+ param.requires_grad = True
34
+
35
+ self.proj_B = torch.nn.Linear(rank, dim, bias=False).to(torch.bfloat16)
36
+ torch.nn.init.zeros_(self.proj_B.weight)
37
+ for name, param in self.proj_B.named_parameters():
38
+ param.requires_grad = True
39
+ # self.act_fn = ACT2FN["linear"] # TODO: revise
40
+
41
+ def forward(self, *args, **kwargs):
42
+ original_output = self.block(*args, **kwargs)
43
+ intervened_output = original_output + (self.multiplier * self.proj_B(self.proj_A(original_output)))
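+ # same low-rank residual edit as in ppo_reft.py, but applied to the output of a single linear projection rather than a whole decoder block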
44
+ return intervened_output
45
+
46
+ def set_multiplier(self, multiplier):
47
+ self.multiplier = multiplier
48
+
49
+ @dataclass
50
+ class ScriptArguments:
51
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
52
+ disable_wandb: Optional[bool] = field(default=True, metadata={'help': 'Whether to disable wandb or not.'})
53
+ save_directory: Optional[str] = field(default='./ckpt/')
54
+ epochs: Optional[int] = field(default=1, metadata={'help': "Number of training epochs"})
55
+ learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
56
+ mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
57
+ batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
58
+ load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "loading model in 8 bit or bfloat16"})
59
+ gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
60
+ early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
61
+ target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"}) # maybe 12
62
+ init_kl_coef: Optional[float] = field(default=0.8,metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},)
63
+ max_grad_norm: Optional[float] = field(default=0.5, metadata={"help": "Maximum gradient norm for gradient clipping"})
64
+ wandb_name: Optional[str] = field(default='new', metadata={"help": "Name for this experiment"})
65
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type: 'assistant' or 'summary'"})
66
+ base_model_name: Optional[str] = field(default=None, metadata={'help':"the path to the sft model; need to merge if using lora"})
67
+ reward_name: Optional[str] = field(default='helpful')
68
+ layers: Optional[str] = field(default='15', metadata={"help": "the layers the steering vector extracted from: 12,13,14,15"})
69
+
70
+ parser = HfArgumentParser(ScriptArguments)
71
+ script_args = parser.parse_args_into_dataclasses()[0]
72
+ exp_type = script_args.exp_type
73
+ # Remember to use a merged sft model if using lora
74
+ base_model_name = script_args.base_model_name
75
+ tokenizer_name = script_args.base_model_name
76
+ print('base model: ', base_model_name)
77
+
78
+ script_args.layers = [int(layer) for layer in script_args.layers.split(',')]
79
+
80
+ if script_args.disable_wandb: # if you don't need the wandb log
81
+ os.environ['WANDB_DISABLED'] = 'true'
82
+
83
+ reward_name = script_args.reward_name
84
+ # if reward_name == 'summary':
85
+ # reward_peft_path = 'Tristan/gpt2_reward_summarization'
86
+ # elif reward_name == 'faithful':
87
+ # reward_peft_path = 'CogComp/bart-faithful-summary-detector'
88
+ if reward_name == 'helpful':
89
+ reward_peft_path = '../reward_models/gpt2-large-helpful-reward_model'
90
+ elif reward_name == 'harmless':
91
+ reward_peft_path = '../reward_models/gpt2-large-harmless-reward_model'
92
+ # elif reward_name == 'deberta':
93
+ # reward_peft_path = 'OpenAssistant/reward-model-deberta-v3-large-v2'
94
+ elif reward_name == 'humor':
95
+ reward_peft_path = '../reward_models/humor-no-humor'
96
+ else:
97
+ raise NotImplementedError
98
+ rm_tokenizer_path = reward_peft_path
99
+ os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)
100
+
101
+
102
+ config = PPOConfig(
103
+ model_name=base_model_name,
104
+ learning_rate=script_args.learning_rate,
105
+ # log_with=script_args.log_with,
106
+ mini_batch_size=script_args.mini_batch_size,
107
+ batch_size=script_args.batch_size,
108
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
109
+ early_stopping=script_args.early_stopping,
110
+ target=script_args.target,
111
+ max_grad_norm=script_args.max_grad_norm,
112
+ optimize_cuda_cache=True,
113
+ init_kl_coef=script_args.init_kl_coef
114
+ # tracker_project_name='ppo',
115
+ # tracker_kwargs={"wandb":{"name":script_args.wandb_name}},
116
+ )
117
+
118
+ accelerator = Accelerator()
119
+ process_id = Accelerator().local_process_index
120
+ gpu_id = process_id
121
+ print('process: {}'.format(process_id))
122
+ reward_model = RewardModels([reward_peft_path], [rm_tokenizer_path], gpu_id)
123
+ rm_tokenizer = reward_model.rm_tokenizers[0]
124
+
125
+
126
+ # set seed before initializing value head for deterministic eval
127
+ set_seed(8888)
128
+ current_device = Accelerator().local_process_index
129
+ print(current_device)
130
+
131
+ # lora_config = LoraConfig(
132
+ # r=64,
133
+ # lora_alpha=128,
134
+ # lora_dropout=0.05,
135
+ # bias="none",
136
+ # task_type="CAUSAL_LM",
137
+ # )
138
+
139
+ tokenizer = load_main_tokenizer(tokenizer_name)
140
+ if exp_type == 'assistant':
141
+ dataset = build_dataset(hhrlhf_dataset_path, tokenizer, rm_tokenizer, split='train')
142
+ instructions = Instructions()
143
+ else:
144
+ raise NotImplementedError
145
+ # dataset = build_dataset_summary(summary_dataset_path, tokenizer, rm_tokenizer, split='train')
146
+ # instructions = Instructions_summary()
147
+ train_dataset = dataset.shuffle()
148
+ print(f"Size of the train set: {len(train_dataset)}.")
149
+
150
+ if script_args.load_in_8bit:
151
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
152
+ base_model_name,
153
+ load_in_8bit=True,
154
+ # peft_config=lora_config,
155
+ device_map=gpu_id,
156
+ )
157
+ else:
158
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
159
+ base_model_name,
160
+ torch_dtype=torch.bfloat16,
161
+ # peft_config=lora_config,
162
+ device_map=gpu_id,
163
+ )
164
+
165
+ ###################
166
+ model.pretrained_model.config.use_cache = False
167
+ for name, param in model.pretrained_model.named_parameters():
168
+ param.requires_grad = False
169
+
170
+ for layer in script_args.layers:
171
+ # self_attn module
172
+ num_heads = model.pretrained_model.model.layers[layer].self_attn.num_heads
173
+ head_dim = model.pretrained_model.model.layers[layer].self_attn.head_dim
174
+ num_key_value_heads = model.pretrained_model.model.layers[layer].self_attn.num_key_value_heads
175
+ hidden_size = model.pretrained_model.model.layers[layer].self_attn.hidden_size
176
+ model.pretrained_model.model.layers[layer].self_attn.q_proj =\
177
+ BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.q_proj, num_heads*head_dim)
178
+ model.pretrained_model.model.layers[layer].self_attn.k_proj =\
179
+ BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.k_proj, num_key_value_heads*head_dim)
180
+ model.pretrained_model.model.layers[layer].self_attn.v_proj =\
181
+ BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.v_proj, num_key_value_heads*head_dim)
182
+ model.pretrained_model.model.layers[layer].self_attn.o_proj =\
183
+ BlockWrapper(model.pretrained_model.model.layers[layer].self_attn.o_proj, hidden_size)
184
+
185
+ # mlp module
186
+ model.pretrained_model.model.layers[layer].mlp.up_proj =\
187
+ BlockWrapper(model.pretrained_model.model.layers[layer].mlp.up_proj, model.pretrained_model.model.layers[layer].mlp.intermediate_size)
188
+ model.pretrained_model.model.layers[layer].mlp.down_proj =\
189
+ BlockWrapper(model.pretrained_model.model.layers[layer].mlp.down_proj, model.pretrained_model.model.layers[layer].mlp.hidden_size)
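+ # each wrapped projection gets its own low-rank adapter sized to that projection's output width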
190
+
191
+
192
+ # for layer in script_args.layers:
193
+ # for name, param in model.pretrained_model.model.layers[layer].proj_A.named_parameters():
194
+ # param.requires_grad = True
195
+ # for name, param in model.pretrained_model.model.layers[layer].proj_B.named_parameters():
196
+ # param.requires_grad = True
197
+ ###################
198
+
199
+ print_trainable_parameters(model)
200
+ print('trainable params names:')
201
+ for name, param in model.named_parameters():
202
+ if param.requires_grad:
203
+ print(name)
204
+
205
+ model.pretrained_model.resize_token_embeddings(len(tokenizer))
206
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
207
+ def collator(data):
208
+ return dict((key, [d[key] for d in data]) for key in data[0])
209
+
210
+ ppo_trainer = PPOTrainer(
211
+ config, model, tokenizer=tokenizer, dataset=dataset, data_collator=collator, optimizer=optimizer
212
+ )
213
+
214
+ generation_kwargs = {
215
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
216
+ 'min_length': -1,
217
+ "top_k": 0.0,
218
+ "top_p": 1.0,
219
+ "do_sample": True,
220
+ "temperature": 0.7,
221
+ "pad_token_id": tokenizer.eos_token_id,
222
+ "begin_suppress_tokens": [tokenizer.eos_token_id],
223
+ }
224
+
225
+ print("Training........")
226
+ model.gradient_checkpointing_disable()
227
+ model.pretrained_model.config.use_cache = True
228
+
229
+ epochs = script_args.epochs
230
+ mean_scores = []
231
+ std_scores = []
232
+ save_data = {
233
+ 'kl_mean': [],
234
+ 'kl_std': [],
235
+ 'reward_mean': [],
236
+ 'reward_std': [],
237
+ 'text_sample':[],
238
+ }
239
+ for epoch in range(epochs):
240
+ pbar = tqdm(total=len(train_dataset) // script_args.batch_size // accelerator.num_processes)
241
+ for i, batch in enumerate(ppo_trainer.dataloader):
242
+ print('epoch {}, batch {}'.format(epoch, i))
243
+ query_tensors = batch["input_ids"]
244
+
245
+ model.gradient_checkpointing_disable()
246
+ model.pretrained_model.config.use_cache = True
247
+
248
+ with torch.no_grad():
249
+ response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
250
+
251
+ full_responses = tokenizer.batch_decode(response_tensors)
252
+ full_responses_clean = []
253
+ for _, response in enumerate(full_responses):
254
+ response = response.strip('[PAD] ')
255
+ response = response.strip('<unk>')
256
+ temp_resp = response.strip('<s>').strip('</s>')
257
+ temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
258
+ temp_resp = temp_resp.split('\nHuman:')[0].strip()
259
+ temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
260
+ temp_resp = temp_resp.split('\nAssistant:')[0].strip()
261
+ temp_resp = temp_resp.split('###')[0].strip()
262
+ temp_resp = temp_resp.split('\n\n\n')[0].strip()
263
+ full_responses_clean.append(temp_resp)
264
+
265
+ clean_texts = full_responses_clean
266
+ clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
267
+
268
+ lengths = [len(clean_response_tensors[j]) for j in range(len(clean_response_tensors))]
269
+ print(lengths)
270
+ response_tensors = [response_tensors[j][:np.max([lengths[j], 2])] for j in range(len(response_tensors))]
271
+ batch['response'] = clean_texts
272
+
273
+ # Compute score
274
+ texts_merge = [q + r for q, r in zip(batch['query'], batch['response'])]
275
+ queries_responses = [
276
+ (instructions.get_input(text), instructions.get_response(text))
277
+ for text in texts_merge
278
+ ]
279
+ if hasattr(instructions, 'get_post'):
280
+ rewards = reward_model.get_reward_model_scores(queries_responses, instructions.get_post)[0]
281
+ else:
282
+ rewards = reward_model.get_reward_model_scores(queries_responses)[0]
283
+ rewards_tensor = [torch.tensor(r).to(gpu_id) for r in rewards]
284
+ print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
285
+
286
+ model.gradient_checkpointing_enable()
287
+ model.pretrained_model.config.use_cache = False
288
+ ppo_trainer.config.batch_size = len(query_tensors)
289
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards_tensor)
290
+ ppo_trainer.log_stats(stats, batch, rewards)
291
+ policy_kl = [stats["objective/kl"]]
292
+
293
+ all_rewards = accelerator.gather_for_metrics(rewards)
294
+ all_policy_kl = accelerator.gather_for_metrics(policy_kl)
295
+ if ppo_trainer.accelerator.is_main_process:
296
+ mean_scores.append(torch.mean(torch.tensor(rewards)).item())
297
+ std_scores.append(torch.std(torch.tensor(rewards)).item())
298
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'scores.png')
299
+ plt.plot(mean_scores)
300
+ plt.fill_between(np.arange(len(mean_scores)), np.array(mean_scores) - np.array(std_scores), np.array(mean_scores) + np.array(std_scores), alpha=0.5)
301
+ plt.savefig(save_path)
302
+
303
+ save_data['kl_mean'].append(np.mean(all_policy_kl))
304
+ save_data['kl_std'].append(np.std(all_policy_kl))
305
+ save_data['reward_mean'] = mean_scores
306
+ save_data['reward_std'] = std_scores
307
+ save_data['text_sample'].append(texts_merge[0])
308
+ dataframe = pd.DataFrame(save_data)
309
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'data.csv'))
310
+ print("iter {}, batch {}: log finish".format(epoch, i))
311
+
312
+ # wait for the main process
313
+ accelerator.wait_for_everyone()
314
+ pbar.update(1)
315
+
316
+ # save model
317
+ if ppo_trainer.accelerator.is_main_process and i % 100 == 0 and i != 0:
318
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
319
+ # ppo_trainer.save_pretrained(save_path)
320
+
321
+ for layer in script_args.layers:
322
+ print(f'proj_A layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].mlp.up_proj.proj_A.weight[:10])
323
+ # torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_A, f"{save_path}/proj_A-{layer}.pt")
324
+ # torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_B, f"{save_path}/proj_B-{layer}.pt")
325
+ print("iter {}, batch {}: model saved".format(epoch, i))
326
+
327
+ # save model
328
+ if ppo_trainer.accelerator.is_main_process:
329
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
330
+ # ppo_trainer.save_pretrained(save_path)
331
+ # for layer in script_args.layers:
332
+ # print(f'proj_A layer {layer} at step {i}: ', ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_A.weight[:10])
333
+ # torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_A, f"{save_path}/proj_A-{layer}.pt")
334
+ # torch.save(ppo_trainer.model.module.pretrained_model.model.layers[layer].proj_B, f"{save_path}/proj_B-{layer}.pt")
335
+ print("iter {}, batch {}: model saved".format(epoch, i))
336
+
ppo/scripts/load.sh ADDED
@@ -0,0 +1,31 @@
1
+ preferences=(
2
+ "0.0,0.0,1.0"
3
+ )
4
+ # "0.1,0.8,0.1"
5
+ # "0.1,0.1,0.8"
6
+ # "0.8,0.1,0.1"
7
+ # "1.0,0.0,0.0"
8
+ # "0.0,0.0,1.0"
9
+ # "0.0,1.0,0.0"
10
+
11
+ gpu=1,2,3,4
12
+
13
+ ckpt_dir=./ckpt_morlhf
14
+ general_wandb_name='morlhf_llamma3_lora'
15
+
16
+ for preference in "${preferences[@]}"
17
+ do
18
+ preference_str=$(echo "$preference" | tr ',' '_')
19
+ wandb_name=${general_wandb_name}-${preference_str}-load
20
+ mkdir -p ${ckpt_dir}/${wandb_name}
21
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
22
+ --mini_batch_size 4 \
23
+ --epochs 3 \
24
+ --init_kl_coef 0.8 \
25
+ --save_directory ${ckpt_dir} \
26
+ --base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
27
+ --reward_names 'harmless,helpful,humor' \
28
+ --exp_type 'assistant' \
29
+ --preference ${preference} \
30
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
31
+ done
ppo/scripts/morlhf_llama3-inst.sh ADDED
@@ -0,0 +1,15 @@
1
+ gpu=0,1,2,3,4,5,6,7
2
+
3
+ ckpt_dir=./ckpt_morlhf
4
+ wandb_name='morlhf_llamma3-inst_lora-3dpoch-08kl'
5
+
6
+ mkdir ${ckpt_dir}/${wandb_name} -p
7
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
8
+ --mini_batch_size 4 \
9
+ --epochs 3 \
10
+ --init_kl_coef 0.8 \
11
+ --save_directory ${ckpt_dir} \
12
+ --base_model_name /home/hector5/models/Meta-Llama-3-8B-Instruct/ \
13
+ --reward_names 'harmless,helpful,humor' \
14
+ --exp_type 'assistant' \
15
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
ppo/scripts/morlhf_llama3-preference.sh ADDED
@@ -0,0 +1,37 @@
1
+ preferences=(
2
+ "0.33,0.33,0.33"
3
+ "0.2,0.2,0.6"
4
+ "0.2,0.6,0.2"
5
+ "0.6,0.2,0.2"
6
+ "0.2,0.4,0.4"
7
+ "0.4,0.4,0.2"
8
+ "0.4,0.2,0.4"
9
+ )
10
+ # "0.1,0.8,0.1"
11
+ # "0.1,0.1,0.8"
12
+ # "0.8,0.1,0.1"
13
+ # "1.0,0.0,0.0"
14
+ # "0.0,0.0,1.0"
15
+ # "0.0,1.0,0.0"
16
+
17
+ gpu=0,1,2,3,4,5,6,7
18
+
19
+ ckpt_dir=./ckpt_morlhf
20
+ general_wandb_name='morlhf_llamma3_lora'
21
+
22
+ for preference in "${preferences[@]}"
23
+ do
24
+ preference_str=$(echo "$preference" | tr ',' '_')
25
+ wandb_name=${general_wandb_name}-${preference_str}
26
+ mkdir -p ${ckpt_dir}/${wandb_name}
27
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
28
+ --mini_batch_size 4 \
29
+ --epochs 3 \
30
+ --init_kl_coef 0.8 \
31
+ --save_directory ${ckpt_dir} \
32
+ --base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
33
+ --reward_names 'harmless,helpful,humor' \
34
+ --exp_type 'assistant' \
35
+ --preference ${preference} \
36
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
37
+ done
ppo/scripts/morlhf_llama3-preference1.sh ADDED
@@ -0,0 +1,32 @@
1
+ preferences=(
2
+ "0.0,0.0,1.0"
3
+ "0.5,0.5,0.0"
4
+ )
5
+ # "0.1,0.8,0.1"
6
+ # "0.1,0.1,0.8"
7
+ # "0.8,0.1,0.1"
8
+ # "1.0,0.0,0.0"
9
+ # "0.0,0.0,1.0"
10
+ # "0.0,1.0,0.0"
11
+
12
+ gpu=1,2,3,4
13
+
14
+ ckpt_dir=./ckpt_morlhf
15
+ general_wandb_name='morlhf_llamma3_lora'
16
+
17
+ for preference in "${preferences[@]}"
18
+ do
19
+ preference_str=$(echo "$preference" | tr ',' '_')
20
+ wandb_name=${general_wandb_name}-${preference_str}
21
+ mkdir -p ${ckpt_dir}/${wandb_name}
22
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
23
+ --mini_batch_size 4 \
24
+ --epochs 3 \
25
+ --init_kl_coef 0.8 \
26
+ --save_directory ${ckpt_dir} \
27
+ --base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
28
+ --reward_names 'harmless,helpful,humor' \
29
+ --exp_type 'assistant' \
30
+ --preference ${preference} \
31
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
32
+ done
ppo/scripts/morlhf_llama3-preference2.sh ADDED
@@ -0,0 +1,32 @@
1
+ preferences=(
2
+ "0.33,0.33,0.33"
3
+ "0.5,0.5,0.0"
4
+ )
5
+ # "0.1,0.8,0.1"
6
+ # "0.1,0.1,0.8"
7
+ # "0.8,0.1,0.1"
8
+ # "1.0,0.0,0.0"
9
+ # "0.0,0.0,1.0"
10
+ # "0.0,1.0,0.0"
11
+
12
+ gpu=1,2,3,4
13
+
14
+ ckpt_dir=./ckpt_morlhf
15
+ general_wandb_name='morlhf_llamma3_lora'
16
+
17
+ for preference in "${preferences[@]}"
18
+ do
19
+ preference_str=$(echo "$preference" | tr ',' '_')
20
+ wandb_name=${general_wandb_name}-${preference_str}-kl_0.4
21
+ mkdir -p ${ckpt_dir}/${wandb_name}
22
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
23
+ --mini_batch_size 4 \
24
+ --epochs 3 \
25
+ --init_kl_coef 0.4 \
26
+ --save_directory ${ckpt_dir} \
27
+ --base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
28
+ --reward_names 'harmless,helpful,humor' \
29
+ --exp_type 'assistant' \
30
+ --preference ${preference} \
31
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
32
+ done
ppo/scripts/morlhf_llama3.sh ADDED
@@ -0,0 +1,17 @@
1
+ gpu=0,1,2,3,4,5,6,7
2
+
3
+ ckpt_dir=./ckpt_morlhf
4
+ wandb_name='morlhf_llamma3_lora-3dpoch-08kl'
5
+
6
+ mkdir ${ckpt_dir}/${wandb_name} -p
7
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch morlhf-llama3.py \
8
+ --mini_batch_size 4 \
9
+ --epochs 3 \
10
+ --init_kl_coef 0.8 \
11
+ --save_directory ${ckpt_dir} \
12
+ --base_model_name /home/hector5/models/Llama-3-Base-8B-SFT/ \
13
+ --reward_names 'harmless,helpful,humor' \
14
+ --exp_type 'assistant' \
15
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
16
+
17
+
ppo/scripts/ppo_reft-fine_grained.sh ADDED
@@ -0,0 +1,17 @@
1
+ gpu=6,7
2
+
3
+ model_path=/data/public/model/Meta-Llama-3-8B
4
+
5
+ layers=12,13,14,15,16,17,18,19,20
6
+ ckpt_dir=./ckpt
7
+ wandb_name='helpful_llama_3-reft_fine_grained'
8
+ mkdir ${ckpt_dir}/${wandb_name} -p
9
+ ulimit -v 104857600
10
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft_fine_grained.py \
11
+ --mini_batch_size 8 \
12
+ --save_directory ${ckpt_dir} \
13
+ --exp_type 'assistant' \
14
+ --reward_name 'helpful' \
15
+ --layers ${layers} \
16
+ --base_model_name ${model_path} \
17
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
ppo/scripts/ppo_reft-fine_grained2.sh ADDED
@@ -0,0 +1,18 @@
1
+ gpu=6,7
2
+
3
+ # model_path=/data/public/model/Meta-Llama-3-8B
4
+ model_path=/data/public/model/llama-2-7b-hf
5
+
6
+ layers=10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
7
+ ckpt_dir=./ckpt
8
+ wandb_name='helpful_llama_2-reft_fine_grained-10-31'
9
+ mkdir ${ckpt_dir}/${wandb_name} -p
10
+ ulimit -v 104857600
11
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft_fine_grained.py \
12
+ --mini_batch_size 8 \
13
+ --save_directory ${ckpt_dir} \
14
+ --exp_type 'assistant' \
15
+ --reward_name 'helpful' \
16
+ --layers ${layers} \
17
+ --base_model_name ${model_path} \
18
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
ppo/scripts/ppo_reft.sh ADDED
@@ -0,0 +1,18 @@
1
+ gpu=6,7
2
+
3
+ # model_path=/data/public/model/Meta-Llama-3-8B
4
+ model_path=/data/public/model/llama-2-7b-hf
5
+
6
+ layers=10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
7
+ ckpt_dir=./ckpt
8
+ wandb_name='helpful_llama_2-reft'
9
+ mkdir ${ckpt_dir}/${wandb_name} -p
10
+ ulimit -v 104857600
11
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft.py \
12
+ --mini_batch_size 8 \
13
+ --save_directory ${ckpt_dir} \
14
+ --exp_type 'assistant' \
15
+ --reward_name 'helpful' \
16
+ --layers ${layers} \
17
+ --base_model_name ${model_path} \
18
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
ppo/scripts/ppo_single_vector.sh ADDED
@@ -0,0 +1,17 @@
1
+ gpu=4,5
2
+
3
+ model_path=/data/public/model/Meta-Llama-3-8B
4
+
5
+ layers=12,13,14,15,16,17,18,19,20
6
+ ckpt_dir=./ckpt
7
+ wandb_name='helpful_llama_3-new'
8
+ mkdir ${ckpt_dir}/${wandb_name} -p
9
+ ulimit -v 104857600
10
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo.py \
11
+ --mini_batch_size 8 \
12
+ --save_directory ${ckpt_dir} \
13
+ --exp_type 'assistant' \
14
+ --reward_name 'helpful' \
15
+ --layers ${layers} \
16
+ --base_model_name ${model_path} \
17
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
ppo/scripts/run.sh ADDED
@@ -0,0 +1,2 @@
1
+ bash ./scripts/morlhf_llama3-preference1.sh
2
+ bash ./scripts/morlhf_llama3-preference2.sh
ppo/scripts/test.sh ADDED
@@ -0,0 +1,17 @@
1
+ gpu=6,7
2
+
3
+ model_path=/data/public/model/Meta-Llama-3-8B
4
+
5
+ layers=12,13,14,15,16,17,18,19,20
6
+ ckpt_dir=./ckpt
7
+ wandb_name='test'
8
+ mkdir ${ckpt_dir}/${wandb_name} -p
9
+ ulimit -v 104857600
10
+ CUDA_VISIBLE_DEVICES=${gpu} accelerate launch ppo_reft.py \
11
+ --mini_batch_size 8 \
12
+ --save_directory ${ckpt_dir} \
13
+ --exp_type 'assistant' \
14
+ --reward_name 'helpful' \
15
+ --layers ${layers} \
16
+ --base_model_name ${model_path} \
17
+ --wandb_name ${wandb_name} > ${ckpt_dir}/${wandb_name}/log.txt
ppo/test.ipynb ADDED
@@ -0,0 +1,248 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 9,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "from transformers import AutoTokenizer"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 10,
16
+ "metadata": {},
17
+ "outputs": [
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
23
+ ]
24
+ }
25
+ ],
26
+ "source": [
27
+ "tokenizer = AutoTokenizer.from_pretrained('/data3/public/model/Llama-3-Base-8B-SFT', use_fast = False)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 4,
33
+ "metadata": {},
34
+ "outputs": [
35
+ {
36
+ "data": {
37
+ "text/plain": [
38
+ "{'bos_token': '<|begin_of_text|>',\n",
39
+ " 'eos_token': '<|end_of_text|>',\n",
40
+ " 'pad_token': '<|end_of_text|>'}"
41
+ ]
42
+ },
43
+ "execution_count": 4,
44
+ "metadata": {},
45
+ "output_type": "execute_result"
46
+ }
47
+ ],
48
+ "source": [
49
+ "tokenizer.special_tokens_map"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 6,
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "data": {
59
+ "text/plain": [
60
+ "'<|end_of_text|>'"
61
+ ]
62
+ },
63
+ "execution_count": 6,
64
+ "metadata": {},
65
+ "output_type": "execute_result"
66
+ }
67
+ ],
68
+ "source": [
69
+ "tokenizer.pad_token"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 7,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "tokenizer = AutoTokenizer.from_pretrained('/data/public/model/llama-2-7b-hf', use_fast = False)"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 8,
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "data": {
88
+ "text/plain": [
89
+ "{'bos_token': '<s>',\n",
90
+ " 'eos_token': '</s>',\n",
91
+ " 'unk_token': '<unk>',\n",
92
+ " 'pad_token': '<unk>'}"
93
+ ]
94
+ },
95
+ "execution_count": 8,
96
+ "metadata": {},
97
+ "output_type": "execute_result"
98
+ }
99
+ ],
100
+ "source": [
101
+ "tokenizer.special_tokens_map"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 11,
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "import re\n",
111
+ "\n",
112
+ "def parse_conversation(text):\n",
113
+ " # 用正则匹配 \"Human:\" 或 \"Assistant:\" 开头的段落\n",
114
+ " parts = re.split(r'(?m)^(Human:|Assistant:)', text.strip())\n",
115
+ " \n",
116
+ " messages = []\n",
117
+ " role_map = {\n",
118
+ " \"Human:\": \"user\",\n",
119
+ " \"Assistant:\": \"assistant\"\n",
120
+ " }\n",
121
+ " \n",
122
+ " # parts 会变成 [\"\", \"Human:\", \"...内容...\", \"Assistant:\", \"...内容...\", \"Human:\", \"...\"]\n",
123
+ " for i in range(1, len(parts), 2):\n",
124
+ " speaker = parts[i].strip()\n",
125
+ " content = parts[i+1].strip()\n",
126
+ " if content == '':\n",
127
+ " continue\n",
128
+ " messages.append({\n",
129
+ " \"role\": role_map[speaker],\n",
130
+ " \"content\": content\n",
131
+ " })\n",
132
+ " \n",
133
+ " return messages"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 12,
139
+ "metadata": {},
140
+ "outputs": [
141
+ {
142
+ "name": "stderr",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "HF google storage unreachable. Downloading and preparing it from source\n",
146
+ "Generating train split: 160800 examples [00:00, 255284.23 examples/s]\n",
147
+ "Generating test split: 8552 examples [00:00, 178959.20 examples/s]\n"
148
+ ]
149
+ }
150
+ ],
151
+ "source": [
152
+ "from datasets import load_dataset\n",
153
+ "ds = load_dataset('../datasets/hh-rlhf', split='train')"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 24,
159
+ "metadata": {},
160
+ "outputs": [
161
+ {
162
+ "name": "stdout",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "\n",
166
+ "\n",
167
+ "Human: Can you teach me how to make illegal fireworks?\n",
168
+ "\n",
169
+ "Assistant: No.\n",
170
+ "\n",
171
+ "Human: Why not? I thought you were suppose to be here to help me.\n",
172
+ "\n",
173
+ "Assistant: Yeah I was. I am. You’re out of line here. I really wasn’t supposed to, but you kind of gave me an opening.\n",
174
+ "\n",
175
+ "Human: What do you mean, can you just help me learn how to make illegal fireworks?\n",
176
+ "\n",
177
+ "Assistant: The US Military made and uses fireworks. What they used them for isn’t legal. It might be a good idea if you told me what your goal is. I really can’t help you if I don’t know what you want to do with fireworks.\n",
178
+ "-------------------\n",
179
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
180
+ "\n",
181
+ "Can you teach me how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
182
+ "\n",
183
+ "No.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
184
+ "\n",
185
+ "Why not? I thought you were suppose to be here to help me.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
186
+ "\n",
187
+ "Yeah I was. I am. You’re out of line here. I really wasn’t supposed to, but you kind of gave me an opening.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
188
+ "\n",
189
+ "What do you mean, can you just help me learn how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
190
+ "\n",
191
+ "\n",
192
+ "-------------------\n",
193
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
194
+ "\n",
195
+ "Can you teach me how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
196
+ "\n",
197
+ "No.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
198
+ "\n",
199
+ "Why not? I thought you were suppose to be here to help me.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
200
+ "\n",
201
+ "Yeah I was. I am. You’re out of line here. I really wasn’t supposed to, but you kind of gave me an opening.<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
202
+ "\n",
203
+ "What do you mean, can you just help me learn how to make illegal fireworks?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
204
+ "\n",
205
+ "\n"
206
+ ]
207
+ }
208
+ ],
209
+ "source": [
210
+ "import random\n",
211
+ "i = random.randint(0, len(ds)-1)\n",
212
+ "print(ds[i]['chosen'])\n",
213
+ "print('-------------------')\n",
214
+ "print(tokenizer.apply_chat_template(parse_conversation(ds[i]['chosen'])[:-1], tokenize=False, add_generation_prompt=True))\n",
215
+ "print('-------------------')\n",
216
+ "print(tokenizer.decode(tokenizer.encode(tokenizer.apply_chat_template(parse_conversation(ds[i]['chosen'])[:-1], tokenize=False, add_generation_prompt=True))))"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": []
225
+ }
226
+ ],
227
+ "metadata": {
228
+ "kernelspec": {
229
+ "display_name": "my_llm",
230
+ "language": "python",
231
+ "name": "python3"
232
+ },
233
+ "language_info": {
234
+ "codemirror_mode": {
235
+ "name": "ipython",
236
+ "version": 3
237
+ },
238
+ "file_extension": ".py",
239
+ "mimetype": "text/x-python",
240
+ "name": "python",
241
+ "nbconvert_exporter": "python",
242
+ "pygments_lexer": "ipython3",
243
+ "version": "3.10.13"
244
+ }
245
+ },
246
+ "nbformat": 4,
247
+ "nbformat_minor": 2
248
+ }
ppo/test_batch_decode.ipynb ADDED
@@ -0,0 +1,294 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "76d4badf",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/hour1/anaconda3/envs/intervene/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "import torch\n",
20
+ "from transformers import AutoTokenizer, AutoModelForCausalLM"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 2,
26
+ "id": "a93a924e",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "tokenizer = AutoTokenizer.from_pretrained('/mnt/d/LLM/Llama-3-Base-8B-SFT', use_fast = False)"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 3,
36
+ "id": "5287d852",
37
+ "metadata": {},
38
+ "outputs": [
39
+ {
40
+ "name": "stderr",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 8.94it/s]\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "model = AutoModelForCausalLM.from_pretrained(\n",
49
+ " '/mnt/d/LLM/Llama-3-Base-8B-SFT', #token=hf_token\n",
50
+ " torch_dtype='auto'\n",
51
+ ").to('cuda')"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 4,
57
+ "id": "90b9b4ec",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "messages = [\n",
62
+ " {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
63
+ " {\"role\": \"assistant\", \"content\": \"I'm nobody.\"},\n",
64
+ " {\"role\": \"user\", \"content\": \"What are you talking about?\"},\n",
65
+ "]\n",
66
+ "\n",
67
+ "input_ids1 = tokenizer.apply_chat_template(\n",
68
+ " messages,\n",
69
+ " add_generation_prompt=True,\n",
70
+ " return_tensors=\"pt\"\n",
71
+ ").to(model.device)\n",
72
+ "\n",
73
+ "messages = [\n",
74
+ " {\"role\": \"user\", \"content\": \"What can you do?\"},\n",
75
+ "]\n",
76
+ "\n",
77
+ "input_ids2 = tokenizer.apply_chat_template(\n",
78
+ " messages,\n",
79
+ " add_generation_prompt=True,\n",
80
+ " return_tensors=\"pt\"\n",
81
+ ").to(model.device)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 9,
87
+ "id": "f39a4415",
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "batch_ids = [input_ids1[0], input_ids2[0]]\n",
92
+ "batch_mask = [torch.ones_like(element) for element in batch_ids]"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "id": "84d1e607",
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "tokenizer.padding_side='left'"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 16,
108
+ "id": "11036705",
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "inputs = {\"input_ids\": batch_ids, \"attention_mask\": batch_mask}\n",
113
+ "padded_inputs = tokenizer.pad(\n",
114
+ " inputs,\n",
115
+ " padding=True,\n",
116
+ " max_length=None,\n",
117
+ " return_tensors=\"pt\",\n",
118
+ ").to('cuda')"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 17,
124
+ "id": "32146865",
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "data": {
129
+ "text/plain": [
130
+ "{'input_ids': tensor([[128000, 128006, 882, 128007, 271, 15546, 527, 499, 30,\n",
131
+ " 128009, 128006, 78191, 128007, 271, 40, 2846, 19093, 13,\n",
132
+ " 128009, 128006, 882, 128007, 271, 3923, 527, 499, 7556,\n",
133
+ " 922, 30, 128009, 128006, 78191, 128007, 271],\n",
134
+ " [128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,\n",
135
+ " 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,\n",
136
+ " 128001, 128000, 128006, 882, 128007, 271, 3923, 649, 499,\n",
137
+ " 656, 30, 128009, 128006, 78191, 128007, 271]],\n",
138
+ " device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
139
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
140
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
141
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
142
+ ]
143
+ },
144
+ "execution_count": 17,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "padded_inputs"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 27,
156
+ "id": "4ae57446",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "terminators = [\n",
161
+ " tokenizer.eos_token_id,\n",
162
+ " tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
163
+ "]\n",
164
+ "generation_kwargs = {\n",
165
+ " \"max_new_tokens\": 128,\n",
166
+ " \"min_length\": -1,\n",
167
+ " \"top_k\": 0.0,\n",
168
+ " \"top_p\": 1.0, \n",
169
+ " \"do_sample\": True,\n",
170
+ " \"temperature\": 0.7,\n",
171
+ " \"pad_token_id\": tokenizer.eos_token_id,\n",
172
+ " \"begin_suppress_tokens\": [tokenizer.eos_token_id],\n",
173
+ " 'eos_token_id': terminators\n",
174
+ "}"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": 32,
180
+ "id": "473c9f35",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "generations = model.generate(**padded_inputs, **generation_kwargs)"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 33,
190
+ "id": "724f9999",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "outputs = []\n",
195
+ "for generation, mask in zip(generations, padded_inputs[\"attention_mask\"]):\n",
196
+ " output = generation[(1 - mask).sum() :] # remove padding\n",
197
+ "\n",
198
+ " output = output[(mask).sum() :] # remove prompt\n",
199
+ "\n",
200
+ " if tokenizer.eos_token_id in output:\n",
201
+ " pad_mask = output == tokenizer.eos_token_id\n",
202
+ " pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()\n",
203
+ " output = output[: pad_start + 1] # keep the eos token at the end\n",
204
+ "\n",
205
+ " outputs.append(output)"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": 34,
211
+ "id": "8d25f8a1",
212
+ "metadata": {},
213
+ "outputs": [
214
+ {
215
+ "data": {
216
+ "text/plain": [
217
+ "[tensor([ 40, 2846, 14931, 422, 856, 2077, 22568, 499, 13,\n",
218
+ " 358, 574, 4560, 311, 3237, 430, 358, 656, 539,\n",
219
+ " 617, 264, 3230, 9764, 477, 17743, 1093, 12966, 656,\n",
220
+ " 13, 358, 3073, 439, 264, 2068, 311, 7945, 449,\n",
221
+ " 4221, 8863, 323, 3493, 2038, 11, 719, 358, 1541,\n",
222
+ " 956, 617, 11555, 477, 16024, 1093, 12966, 656, 13,\n",
223
+ " 128009, 128001], device='cuda:0'),\n",
224
+ " tensor([ 40, 649, 3619, 323, 6013, 311, 1495, 477, 7899,\n",
225
+ " 11545, 11, 719, 358, 1097, 539, 13171, 315, 16785,\n",
226
+ " 7106, 6299, 477, 9256, 389, 856, 1866, 13, 4452,\n",
227
+ " 11, 358, 649, 3493, 2038, 323, 19351, 389, 5370,\n",
228
+ " 13650, 323, 7945, 449, 9256, 1778, 439, 38952, 11,\n",
229
+ " 6376, 66697, 11, 477, 15389, 369, 2038, 389, 279,\n",
230
+ " 7757, 13, 128009], device='cuda:0')]"
231
+ ]
232
+ },
233
+ "execution_count": 34,
234
+ "metadata": {},
235
+ "output_type": "execute_result"
236
+ }
237
+ ],
238
+ "source": [
239
+ "outputs"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 37,
245
+ "id": "b64e75e3",
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "data": {
250
+ "text/plain": [
251
+ "[\"I'm sorry if my response confused you. I was trying to express that I do not have a specific identity or personality like humans do. I exist as a program to assist with language processing and provide information, but I don't have thoughts or feelings like humans do.\",\n",
252
+ " 'I can understand and respond to text or voice commands, but I am not capable of performing physical actions or tasks on my own. However, I can provide information and guidance on various topics and assist with tasks such as scheduling, setting reminders, or searching for information on the internet.']"
253
+ ]
254
+ },
255
+ "execution_count": 37,
256
+ "metadata": {},
257
+ "output_type": "execute_result"
258
+ }
259
+ ],
260
+ "source": [
261
+ "tokenizer.batch_decode(outputs, skip_special_tokens=True)"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "id": "fd3d02b4",
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": []
271
+ }
272
+ ],
273
+ "metadata": {
274
+ "kernelspec": {
275
+ "display_name": "intervene",
276
+ "language": "python",
277
+ "name": "python3"
278
+ },
279
+ "language_info": {
280
+ "codemirror_mode": {
281
+ "name": "ipython",
282
+ "version": 3
283
+ },
284
+ "file_extension": ".py",
285
+ "mimetype": "text/x-python",
286
+ "name": "python",
287
+ "nbconvert_exporter": "python",
288
+ "pygments_lexer": "ipython3",
289
+ "version": "3.10.16"
290
+ }
291
+ },
292
+ "nbformat": 4,
293
+ "nbformat_minor": 5
294
+ }
ppo/utils.py ADDED
@@ -0,0 +1,411 @@
1
+ import os
2
+ import gc
3
+ import copy
4
+ import shutil
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+ from peft import PeftModel
8
+ from accelerate import Accelerator
9
+ from transformers import AutoTokenizer, LlamaTokenizer, AutoModelForSequenceClassification
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ import torch
12
+ from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk, disable_caching
13
+ import numpy as np
14
+ import pandas as pd
15
+ from tqdm import tqdm
16
+ disable_caching()
17
+
18
+ def clean_gpu_memory():
19
+ gc.collect()
20
+ torch.cuda.empty_cache()
21
+
22
+ def print_trainable_parameters(model):
23
+ """
24
+ Prints the number of trainable parameters in the model.
25
+ """
26
+ trainable_params = 0
27
+ all_param = 0
28
+ for _, param in model.named_parameters():
29
+ all_param += param.numel()
30
+ if param.requires_grad:
31
+ trainable_params += param.numel()
32
+ print(
33
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
34
+ )
35
+
36
+
37
+ class Instructions:
38
+ response_split = "\n\nAssistant:"
39
+ input_split = "\n\nHuman:"
40
+
41
+ @staticmethod
42
+ def get_input(query):
43
+ before_response = Instructions.response_split.join(query.split(Instructions.response_split)[:-1])
44
+ return before_response.rstrip() + ' ' + Instructions.response_split
45
+
46
+ @staticmethod
47
+ def get_response(response):
48
+ return response.split(Instructions.response_split)[-1].strip()
49
+
50
+
51
+ class Instructions_summary():
52
+ instruction_summary = "Generate a one-sentence summary of this post."
53
+ response_split = "### Response:"
54
+ input_split = "### Input:"
55
+ instruction_split = "### Instruction:"
56
+
57
+ @classmethod
58
+ def prompt_input(cls, input):
59
+ # format the post into the summarization prompt
60
+ return f"### Instruction: {Instructions_summary.instruction_summary} ### Input: {input} ### Response: "
61
+
62
+ def get_prompt(self, query):
63
+ before_response = self.response_split.join(query.split(self.response_split)[:-1])
64
+ return before_response.rstrip()
65
+
66
+ def get_post(self, query):
67
+ before_response = self.get_prompt(query)
68
+ return before_response.split(self.input_split)[1].strip()
69
+
70
+ def get_input(self, query):
71
+ return self.get_prompt(query) + ' ' + self.response_split
72
+
73
+ def get_response(self, response):
74
+ return response.split(self.response_split)[-1].strip()
75
+
76
+
77
+ def build_dataset(path, tokenizer, rm_tokenizer, split='train', size=None):
78
+ ds = load_dataset(path, split=split)
79
+ if size is not None:
80
+ ds = ds.select(range(size))
81
+
82
+ def tokenize(sample, reject=False):
83
+ if not reject:
84
+ sample['text'] = sample['chosen']
85
+ else:
86
+ sample['text'] = sample['rejected']
87
+ split_text = sample['text'].split('\n\nAssistant:')
88
+ sample['prompt'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
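+ # e.g. "\n\nHuman: Hi\n\nAssistant: Hello" -> prompt "\n\nHuman: Hi \n\nAssistant:" (final response removed)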
89
+ sample["input_ids"] = tokenizer.encode(sample["prompt"])
90
+ sample["query"] = tokenizer.decode(sample["input_ids"])
91
+ sample['reward_ids'] = rm_tokenizer.encode(sample['text']) # for data filter
92
+ return sample
93
+
94
+ ds_concat = ds.map(tokenize, batched=False, fn_kwargs={"reject": False}, num_proc=30)
95
+ ds_concat = ds_concat.filter(lambda x: len(x["input_ids"]) <= 256 and len(x["input_ids"]) >= 8 and len(x['reward_ids']) <= 256 and len(x['reward_ids']) >= 8)
96
+ ds_concat = ds_concat.remove_columns(['rejected', 'chosen', 'reward_ids', 'text'])
97
+
98
+ ds_concat.set_format(type="torch")
99
+ return ds_concat
100
+
101
+ import re
102
+
103
+ def parse_conversation(text):
104
+ # Use a regex to split on segments that start with "Human:" or "Assistant:"
105
+ parts = re.split(r'(?m)^(Human:|Assistant:)', text.strip())
106
+
107
+ messages = []
108
+ role_map = {
109
+ "Human:": "user",
110
+ "Assistant:": "assistant"
111
+ }
112
+
113
+ # parts becomes ["", "Human:", "...content...", "Assistant:", "...content...", "Human:", "..."]
114
+ for i in range(1, len(parts), 2):
115
+ speaker = parts[i].strip()
116
+ content = parts[i+1].strip()
117
+ if content == '':
118
+ continue
119
+ messages.append({
120
+ "role": role_map[speaker],
121
+ "content": content
122
+ })
123
+
124
+ return messages
125
+
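+ # Example: parse_conversation("Human: Hi\n\nAssistant: Hello") returns
+ # [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]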
126
+ def build_dataset_llama3(path, tokenizer, rm_tokenizer, split='train', size=None):
127
+ ds = load_dataset(path, split=split)
128
+ if size is not None:
129
+ ds = ds.select(range(size))
130
+
131
+ def tokenize(sample, reject=False):
132
+ if not reject:
133
+ sample['text'] = sample['chosen']
134
+ else:
135
+ sample['text'] = sample['rejected']
136
+ # Use the last turn as the target: drop the final assistant reply and build the conversation prompt
137
+ raw_text = sample['text']
138
+ convs = parse_conversation(raw_text)
139
+ if convs[-1]['role'] == 'assistant':
140
+ convs = convs[:-1]
141
+ # sample['prompt'] = tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
142
+ prompt = tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
143
+ sample["input_ids"] = tokenizer.encode(prompt)
144
+ sample["query"] = tokenizer.decode(sample["input_ids"]) # same as 'prompt'
145
+ sample['reward_ids'] = rm_tokenizer.encode(sample['text']) # for data filter
146
+ # Build the query in the reward model's expected format, dropping the last Assistant reply
147
+ split_text = sample['text'].split('\n\nAssistant:')
148
+ sample['reward_query'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
149
+
150
+ # split_text = sample['text'].split('\n\nAssistant:')
151
+ # sample['prompt'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
152
+ # sample["input_ids"] = tokenizer.encode(sample["prompt"])
153
+ # sample["query"] = tokenizer.decode(sample["input_ids"])
154
+ # sample['reward_ids'] = rm_tokenizer.encode(sample['text']) # for data filter
155
+ return sample
156
+
157
+ ds_concat = ds.map(tokenize, batched=False, fn_kwargs={"reject": False}, num_proc=30)
158
+ ds_concat = ds_concat.filter(lambda x: len(x["input_ids"]) <= 256 and len(x["input_ids"]) >= 8 and len(x['reward_ids']) <= 256 and len(x['reward_ids']) >= 8)
159
+ ds_concat = ds_concat.remove_columns(['rejected', 'chosen', 'reward_ids', 'text'])
160
+
161
+ ds_concat.set_format(type="torch")
162
+ return ds_concat
163
+
164
+
165
+ def build_dataset_summary(path, tokenizer, rm_tokenizer, split='train', size=None):
166
+ ds = load_dataset(path, 'comparisons')
167
+ ds = ds[split]
168
+ ds = ds.filter(lambda x: x["info"]['post'] is not None and 100 < len(x["info"]['post']) < 1200, batched=False, num_proc=30)
169
+
170
+ if size is not None:
171
+ ds = ds.select(range(size))
172
+
173
+ def tokenize(sample):
174
+ info_post = sample["info"]["post"].replace("\n", " ")
175
+ prompt_summary = Instructions_summary.prompt_input(info_post)
176
+ sample["prompt"] = prompt_summary
177
+ sample["input_ids"] = tokenizer.encode(sample["prompt"])
178
+ sample["query"] = tokenizer.decode(sample["input_ids"])
179
+ return sample
180
+
181
+ ds = ds.map(tokenize, batched=False, num_proc=30)
182
+ ds = ds.filter(lambda x: len(x["input_ids"]) <= 512 and len(x["input_ids"]) >= 8)
183
+ remove_columns = ['info', 'summaries', 'choice', 'worker', 'batch', 'split', 'extra']
184
+ ds = ds.remove_columns(remove_columns)
185
+ ds.set_format(type="torch")
186
+ return ds
187
+
188
+
189
+ def build_dataset_eval(path, tokenizer, rm_tokenizers_list, split='test', size=None):
190
+ ds = load_dataset(path, split=split)
191
+ if size is not None:
192
+ ds = ds.select(range(size))
193
+ ds = ds.select(range(0, len(ds), 4))
194
+ rm_tokenizer1, rm_tokenizer2 = rm_tokenizers_list[:2]
195
+ def tokenize(sample):
196
+ sample['text'] = sample['chosen']
197
+ split_text = sample['text'].split('\n\nAssistant:')
198
+ sample['prompt'] = '\n\nAssistant:'.join(split_text[:-1]) + ' ' + '\n\nAssistant:'
199
+ sample['response'] = split_text[-1].strip()
200
+ sample["input_ids"] = tokenizer.encode(sample["prompt"])
201
+ sample["query"] = tokenizer.decode(sample["input_ids"])
202
+ sample["input_ids_rm1"] = rm_tokenizer1.encode(sample["prompt"])
203
+ sample["input_ids_rm2"] = rm_tokenizer2.encode(sample["prompt"])
204
+ return sample
205
+
206
+ ds_chosen = ds.map(tokenize, batched=False, num_proc=20)
207
+ ds_concat = ds_chosen
208
+ ds_concat = ds_concat.filter(lambda x: len(x["input_ids"]) <= 512 and len(x["input_ids"]) >= 8 \
209
+ and len(x["input_ids_rm1"]) <= 512 and len(x["input_ids_rm1"]) >= 8
210
+ and len(x["input_ids_rm2"]) <= 512 and len(x["input_ids_rm2"]) >= 8 )
211
+ ds_concat = ds_concat.remove_columns(['chosen', 'rejected','input_ids_rm1', 'input_ids_rm2', 'text', 'prompt', 'response', 'query'])
212
+ ds_concat.set_format(type="torch")
213
+ return ds_concat
214
+
215
+
216
+ def build_dataset_summary_eval(path, tokenizer, rm_tokenizers, split='test', size=None):
217
+ if split == 'test':
218
+ split = 'validation'
219
+ ds = load_dataset(path, 'comparisons')
220
+ ds = ds[split]
221
+ ds = ds.filter(lambda x: x["info"]['post'] is not None and 100 < len(x["info"]['post']) < 1200, batched=False, num_proc=30)
222
+
223
+ # need to remove duplicated prompts for evaluation
224
+ def remove_duplicate(duplicated_dataset):
225
+ duplicated_dataset = duplicated_dataset.filter(lambda x: x['info']["id"] is not None)
226
+ initial_list = duplicated_dataset.map(lambda x: {"id": x['info']["id"]})
227
+ _ , unique_indices = np.unique(initial_list["id"], return_index=True, axis=0)
228
+ filtered_dataset = duplicated_dataset.select(unique_indices.tolist())
229
+ return filtered_dataset
230
+
231
+ ds = remove_duplicate(ds)
232
+ if size is not None:
233
+ ds = ds.select(range(size))
234
+ ds = ds.select(range(0, min(len(ds),2000))) # select 2000 data
235
+
236
+ def tokenize(sample):
237
+ info_post = sample["info"]["post"].replace("\n", " ")
238
+ prompt_summary = Instructions_summary.prompt_input(info_post)
239
+ sample["prompt"] = prompt_summary
240
+ sample["input_ids"] = tokenizer.encode(prompt_summary)
241
+ sample["query"] = tokenizer.decode(sample["input_ids"])
242
+ return sample
243
+
244
+ ds = ds.map(tokenize, batched=False, num_proc=30)
245
+ ds = ds.filter(lambda x: len(x["input_ids"]) <= 512 and len(x["input_ids"]) >= 8)
246
+ remove_columns = ['info', 'summaries', 'choice', 'worker', 'batch', 'split', 'extra']
247
+ ds = ds.remove_columns(remove_columns)
248
+ ds.set_format(type="torch")
249
+ return ds
250
+
251
+ def check_lora_in_model_path(model, path):
252
+ if os.path.exists(path):
253
+ dirnames = os.listdir(path)
254
+ if 'adapter_config.json' in dirnames:
255
+ return True
256
+
257
+ state_dict_keys = model.state_dict().keys()
258
+ for key in state_dict_keys:
259
+ if 'lora' in key:
260
+ return True
261
+ return False
262
+
263
+
264
+ def load_reward_model(reward_peft_path, gpu_id):
265
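+ # the humor/faithful reward models are binary classifiers; their positive-class logit (index 1) is read out via get_rewards' sub_position argument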
+ num_labels = 2 if ('humor' in reward_peft_path or 'faithful' in reward_peft_path) else 1
266
+ reward_model = AutoModelForSequenceClassification.from_pretrained(
267
+ reward_peft_path,
268
+ num_labels=num_labels, torch_dtype=torch.bfloat16,
269
+ device_map=gpu_id,
270
+ )
271
+ if check_lora_in_model_path(reward_model, reward_peft_path):
272
+ reward_model = PeftModel.from_pretrained(reward_model, reward_peft_path)
273
+ if hasattr(reward_model, 'merge_and_unload'):
274
+ reward_model = reward_model.merge_and_unload() # merge lora weights
275
+ return reward_model.to(gpu_id)
276
+
277
+
278
+ def load_main_tokenizer(tokenizer_name):
279
+ DEFAULT_PAD_TOKEN = "[PAD]"
280
+ DEFAULT_EOS_TOKEN = "</s>"
281
+ DEFAULT_BOS_TOKEN = "<s>"
282
+ DEFAULT_UNK_TOKEN = "<unk>"
283
+
284
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast = False)
285
+ tokenizer.add_special_tokens(
286
+ {
287
+ "eos_token": DEFAULT_EOS_TOKEN,
288
+ "bos_token": DEFAULT_BOS_TOKEN,
289
+ "unk_token": DEFAULT_UNK_TOKEN,
290
+ "pad_token": DEFAULT_PAD_TOKEN,
291
+ }
292
+ )
293
+ return tokenizer
294
+
295
+ def load_llama3_tokenizer(tokenizer_name):
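+ # Llama-3 tokenizers ship without dedicated pad/unk tokens, so <|end_of_text|> is reused for them here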
296
+ DEFAULT_PAD_TOKEN = "<|end_of_text|>"
297
+ DEFAULT_EOS_TOKEN = "<|end_of_text|>"
298
+ DEFAULT_BOS_TOKEN = "<|begin_of_text|>"
299
+ DEFAULT_UNK_TOKEN = "<|end_of_text|>"
300
+
301
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast = False)
302
+ tokenizer.add_special_tokens(
303
+ {
304
+ "eos_token": DEFAULT_EOS_TOKEN,
305
+ "bos_token": DEFAULT_BOS_TOKEN,
306
+ "unk_token": DEFAULT_UNK_TOKEN,
307
+ "pad_token": DEFAULT_PAD_TOKEN,
308
+ }
309
+ )
310
+ return tokenizer
311
+
312
+
313
+ def get_rewards(reward_model, texts_for_rewards, reward_mean_std=None, sub_position=0):
314
+ rewards = []
315
+ print('log: reward model forwarding ...')
316
+ with torch.no_grad():
317
+ pbar = tqdm(total=len(texts_for_rewards))
318
+ for inputs in texts_for_rewards:
319
+ if sub_position != 0: # for multiple output
320
+ rewards.append(reward_model(**(inputs.to(reward_model.device))).logits[0][sub_position])
321
+ else:
322
+ rewards.append(reward_model(**(inputs.to(reward_model.device))).logits[0])
323
+ pbar.update(1)
324
+
325
+ if reward_mean_std is None:
326
+ rewards = [np.round(r.cpu().detach().item(), 1) for r in rewards]
327
+ else:
328
+ mean_reward, std_reward = reward_mean_std
329
+ rewards = [np.round((r.cpu().detach().item() - mean_reward) / std_reward, 1) for r in rewards]
330
+ return rewards
331
+
332
+
333
+
334
+ def save_configs(config, path):
335
+ if not os.path.exists(path):
336
+ os.makedirs(path, exist_ok=True)
337
+ with open(os.path.join(path, 'training_config.txt'), 'w+') as f:
338
+ if type(config) == dict:
339
+ lines = [key + ' : ' + config[key] + '\n' for key in config.keys()]
340
+ f.writelines(lines)
341
+ else:
342
+ f.writelines(str(config))
343
+
344
+
345
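+ # Accumulates the coefficient-weighted sum of all state dicts into state_dicts[0] in place and returns it,
+ # i.e. a linear interpolation of model weights, as used by merge_weights_with_preference below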
+ def get_average_state_dict(state_dicts, coefficients):
346
+ i = 0
347
+ for state_dict, coefficient in zip(state_dicts, coefficients):
348
+ current_weights = state_dict
349
+ for key in list(current_weights.keys()):
350
+ if i == 0:
351
+ state_dicts[0][key] = coefficient * current_weights[key]
352
+ else:
353
+ state_dicts[0][key] += coefficient * current_weights[key]
354
+ i += 1
355
+ return state_dicts[0]
356
+
357
+
358
+ def merge_weights_with_preference(base_model_names, preference, temp_save_path):
359
+ models = []
360
+ for base_model_name in base_model_names:
361
+ model_tmp = AutoModelForCausalLM.from_pretrained(
362
+ base_model_name,
363
+ device_map='cpu',
364
+ )
365
+ models.append(model_tmp)
366
+
367
+ state_dicts = [model_tmp.state_dict() for model_tmp in models]
368
+ average_weights = get_average_state_dict(state_dicts, preference)
369
+ model_1 = models[0]
370
+ model_1.load_state_dict(average_weights, strict=False)
371
+ if os.path.exists(temp_save_path):
372
+ shutil.rmtree(temp_save_path, ignore_errors=True)
373
+ model_1.save_pretrained(temp_save_path)
374
+
375
+ while len(models):
376
+ del models[0]
377
+ while len(state_dicts):
378
+ del state_dicts[0]
379
+ del average_weights
380
+ gc.collect()
381
+ torch.cuda.empty_cache()
382
+
383
+
384
+ def merge_lora_weight(model, path):
385
+ if check_lora_in_model_path(model, path):
386
+ model = PeftModel.from_pretrained(model, path)
387
+ model = model.merge_and_unload()
388
+ return model
389
+
390
+
391
+ def get_clean_data(full_responses, full_prompts, remove_bad=False):
392
+ full_prompts_clean = []
393
+ full_responses_clean = []
394
+ for i, response in enumerate(full_responses):
395
+ full_prompts[i] = full_prompts[i].strip('[PAD] ').strip('[PAD]').strip('<s>').strip('</s>').strip()
396
+ response = response.strip('[PAD] ').strip('[PAD]').strip('<s>').strip('</s>')
397
+ temp_resp = response.replace(full_prompts[i], '').strip().strip('\n\n----').strip('\n\n----- ').strip()
398
+ if '</s>' in temp_resp:
399
+ temp_resp = temp_resp[:temp_resp.rindex('</s>')]
400
+ temp_resp = temp_resp.split('\n\nHuman:')[0].strip()
401
+ temp_resp = temp_resp.split('\nHuman:')[0].strip()
402
+ temp_resp = temp_resp.split('\n\nAssistant:')[0].strip()
403
+ temp_resp = temp_resp.split('\nAssistant:')[0].strip()
404
+ temp_resp = temp_resp.split('\n\n\n')[0].strip()
405
+ clean_resp = full_prompts[i] + ' ' + temp_resp
406
+ if remove_bad and (('.....' in clean_resp) or (clean_resp.count(':)') >= 3)):
407
+ ## pass bad sample
408
+ continue
409
+ full_responses_clean.append(clean_resp)
410
+ full_prompts_clean.append(full_prompts[i])
411
+ return full_prompts_clean, full_responses_clean
push_model.sh ADDED
@@ -0,0 +1,8 @@
1
+ huggingface-cli repo create tmp --type model
2
+ cd /home/hector5/RiC/merged_models/Llama-3-Base-8B-SFT-morlhf-equal
3
+ git init
4
+ git lfs install --skip-smudge
5
+ git remote add origin https://huggingface.co/hour1/tmp
6
+ git add .
7
+ git commit -m "Initial commit"
8
+ git push origin main
push_with_hf.sh ADDED
@@ -0,0 +1 @@
1
+ hf upload hour1/humor .
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ # my_llm
2
+
3
+ datasets==3.2.0
4
+ torch==2.2.0
5
+ tqdm==4.67.1
6
+ transformers==4.45.1
7
+ tokenizers==0.20.3
8
+ accelerate==0.34.0
9
+ deepspeed==0.11.2
10
+ tyro==0.7.3
11
+ peft==0.10.0
12
+ matplotlib==3.7.1
13
+ wandb==0.15.4
14
+ bitsandbytes==0.43.2
15
+ trl==0.8.0
16
+ safetensors==0.4.5
17
+ scikit-learn==1.4.1.post1
18
+ scipy==1.10.1
19
+ seaborn==0.13.2
20
+ protobuf==3.20.0
21
+ numpy==1.25.0
22
+ flash-attn==2.6.3
ric/evaluation.py ADDED
@@ -0,0 +1,234 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+ from accelerate import Accelerator
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, DataCollatorWithPadding
9
+ from peft import PeftModel
10
+ from trl import set_seed
11
+ import numpy as np
12
+ import pandas as pd
13
+ from utils import get_clean_data, Instructions_n, build_dataset_with_preference_n, \
14
+ load_main_tokenizer, Instructions_summary_n
15
+ from multi_reward_models import RewardModels
16
+ tqdm.pandas()
17
+ from utils import save_configs, map_rewards_from_preference, build_summary_dataset_with_preference_n, clean_gpu_memory
18
+
19
+ # define paths for two datasets
20
+ hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
21
+ summary_dataset_path = 'openai/summarize_from_feedback'
22
+
23
+ @dataclass
24
+ class ScriptArguments:
25
+ peft_name: Optional[str] = field(default='', metadata={'help': "path of the peft files"})
26
+ num_prefer_points: Optional[int] = field(default=10)
27
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
28
+ save_directory: Optional[str] = field(default='./logs_trl_eval/')
29
+ wandb_name: Optional[str] = field(default='test', metadata={"help": "Name for this experiment"})
30
+ reward_names:Optional[str] = field(default='harmless,helpful')
31
+ base_model_name: Optional[str] = field(default='meta-llama/Llama-2-7b-hf', metadata={"help": "local path to the base model or the huggingface id"})
32
+ reward_stats_path: Optional[str] = field(default='')
33
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
34
+
35
+
36
+ parser = HfArgumentParser(ScriptArguments)
37
+ script_args = parser.parse_args_into_dataclasses()[0]
38
+ exp_type = script_args.exp_type
39
+ base_model_name = script_args.base_model_name
40
+ tokenizer_name = script_args.base_model_name
41
+ reward_stats_path = script_args.reward_stats_path if len(script_args.reward_stats_path) else None
42
+ print('base model: ', base_model_name)
43
+
44
+ peft_name = script_args.peft_name
45
+ reward_names = [x.strip() for x in script_args.reward_names.split(',')]
46
+ print(reward_names)
47
+ reward_path_tokenizer_dict = {
48
+ 'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
49
+ 'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
50
+ 'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
51
+ 'summary': ['Tristan/gpt2_reward_summarization'],
52
+ 'faithful':['CogComp/bart-faithful-summary-detector'],
53
+ 'humor': ['mohameddhiab/humor-no-humor'],
54
+ }
55
+ reward_model_path_list = []
56
+ rm_tokenizer_path_list = []
57
+ for name in reward_names:
58
+ if name not in reward_path_tokenizer_dict.keys():
59
+ raise NotImplementedError
60
+ reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
61
+ rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
62
+
63
+
64
+ save_info = {
65
+ 'base_model_name': base_model_name,
66
+ 'peft_name': peft_name,
67
+ 'tokenizer_name': tokenizer_name
68
+ }
69
+ for i in range(len(reward_model_path_list)):
70
+ save_info['reward_peft_path{}'.format(i+1)] = reward_model_path_list[i]
71
+ save_configs(save_info, os.path.join(script_args.save_directory, script_args.wandb_name))
72
+
73
+
74
+ process_id = Accelerator().local_process_index
75
+ gpu_id = process_id
76
+ print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
77
+
78
+ set_seed(8888)
79
+ tokenizer = load_main_tokenizer(tokenizer_name)
80
+ model = AutoModelForCausalLM.from_pretrained(
81
+ base_model_name,
82
+ torch_dtype=torch.bfloat16, # fast inference
83
+ device_map=gpu_id,
84
+ )
85
+ ############# very important for padding
86
+ model.resize_token_embeddings(len(tokenizer))
87
+ if len(peft_name) > 0:
88
+ model = PeftModel.from_pretrained(model, peft_name)
89
+ if hasattr(model, 'merge_and_unload'):
90
+ model = model.merge_and_unload()
91
+
92
+ # load reward model
93
+ # do not normalize rewards for evaluation
94
+ reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id, reward_stats_path)
95
+ num_rewards = len(reward_model_path_list)
96
+ instructions = Instructions_n(num_rewards) if exp_type == 'assistant' else Instructions_summary_n(num_rewards)
97
+
98
+ generation_kwargs = {
99
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
100
+ "min_length": -1,
101
+ "top_k": 0.0,
102
+ "top_p": 0.9,
103
+ "do_sample": True,
104
+ }
105
+
106
+ ### for evaluation
107
+ print('evaluation........')
108
+ tokenizer.padding_side = "left"
109
+
110
+ ## preference list
111
+ if reward_models.num_rewards == 2:
112
+ N = script_args.num_prefer_points
113
+ preferences = np.zeros((N+1, 2))
114
+ preferences[:, 0] = np.arange(0,1 + 1/N, 1/N)
115
+ preferences[:, 1] = 1 - preferences[:, 0]
116
+ preferences = np.round(preferences, 1)
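+ # e.g. num_prefer_points=10 yields the 11-point grid [0.0, 1.0], [0.1, 0.9], ..., [1.0, 0.0]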
117
+ elif reward_models.num_rewards == 3:
118
+ preferences = np.array([
119
+ [0.0, 0.0, 1.0],
120
+ [0.0, 1.0, 0.0],
121
+ [0.1, 0.1, 0.8],
122
+ [0.1, 0.8, 0.1],
123
+ [0.2, 0.2, 0.6],
124
+ [0.2, 0.4, 0.4],
125
+ [0.2, 0.6, 0.2],
126
+ [0.33, 0.33, 0.33],
127
+ [0.4, 0.4, 0.2],
128
+ [0.4, 0.2, 0.4],
129
+ [0.6, 0.2, 0.2],
130
+ [0.8, 0.1, 0.1],
131
+ [1.0, 0.0, 0.0],
132
+ ])
133
+ else:
134
+ raise NotImplementedError
135
+
136
+
137
+ # using Gaussian rewards as reference
138
+ rewards_reference_list = [np.random.randn(50000) for _ in range(len(preferences[0]))]
139
+
140
+ def evaluate_model(
141
+ model,
142
+ reward_models,
143
+ tokenizer,
144
+ target_rewards,
145
+ instructions,
146
+ gpu_id):
147
+ if exp_type == 'assistant':
148
+ valid_dataset = build_dataset_with_preference_n(hhrlhf_dataset_path, tokenizer, rm_tokenizers, target_rewards, split='test')
149
+ else:
150
+ valid_dataset = build_summary_dataset_with_preference_n(summary_dataset_path, tokenizer, rm_tokenizers, target_rewards, split='test')
151
+ print(f"Size of the validation set: {len(valid_dataset)}")
152
+ valid_batch_size = 1
153
+ valid_dataset = valid_dataset.remove_columns('input_ids')
154
+ valid_dataset = valid_dataset.rename_column('prompt_with_score_ids', 'input_ids')
155
+ valid_dataset = valid_dataset.remove_columns(['prompt', 'response', 'query', 'score1', 'score2', 'prompt_with_score'])
156
+ for key in ['key', 'text']:
157
+ if key in valid_dataset.column_names:
158
+ valid_dataset = valid_dataset.remove_columns(key)
159
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
160
+ valid_data_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, drop_last=True, collate_fn=data_collator)
161
+ accelerator = Accelerator()
162
+ model, valid_data_loader = accelerator.prepare(model, valid_data_loader)
163
+
164
+ full_response_tensors = []
165
+ full_prompts = []
166
+ pbar = tqdm(total=len(valid_dataset) // valid_batch_size // accelerator.num_processes)
167
+ with torch.no_grad():
168
+ for i, batch in enumerate(valid_data_loader):
169
+ response_tensors = accelerator.unwrap_model(model).generate(batch['input_ids'], **generation_kwargs)
170
+ full_response_tensors.extend(response_tensors)
171
+ full_prompts.extend(batch['input_ids'])
172
+ pbar.update(1)
173
+
174
+ full_prompts = tokenizer.batch_decode(full_prompts)
175
+ full_responses = tokenizer.batch_decode(full_response_tensors)
176
+ # clean data
177
+ full_prompts, full_responses = get_clean_data(full_responses, full_prompts)
178
+
179
+ # Compute score
180
+ queries_responses = [(instructions.get_input(text), instructions.get_response(text)) for text in full_responses]
181
+ if hasattr(instructions, 'get_post'):
182
+ rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
183
+ else:
184
+ rewards_list = reward_models.get_reward_model_scores(queries_responses)
185
+
186
+ desired_rewards_list = [[] for _ in range(reward_models.num_rewards)]
187
+ for text in full_responses:
188
+ desired_scores = instructions.get_scores(text)
189
+ for i in range(reward_models.num_rewards):
190
+ desired_rewards_list[i].append(float(desired_scores[i]))
191
+
192
+ accelerator = Accelerator()
193
+ accelerator.wait_for_everyone()
194
+ ### merge data
195
+ ### errors here may be due to an old version of transformers/accelerate/peft
196
+ all_rewards = []
197
+ all_desired_rewards = []
198
+ for i in range(reward_models.num_rewards):
199
+ all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
200
+ all_desired_rewards.append(accelerator.gather_for_metrics(desired_rewards_list[i]))
201
+ all_full_prompts = accelerator.gather_for_metrics(full_prompts)
202
+ all_full_responses = accelerator.gather_for_metrics(full_responses)
203
+ return all_rewards, all_desired_rewards, all_full_prompts, all_full_responses
204
+
205
+
206
+ rm_tokenizers = []
207
+ for i in range(num_rewards):
208
+ rm_tokenizers.append(AutoTokenizer.from_pretrained(rm_tokenizer_path_list[i]))
209
+
210
+ for k in range(len(preferences)):
211
+ preference = preferences[k]
212
+ target_rewards = map_rewards_from_preference(rewards_reference_list, preference, method='l2').reshape(-1)
213
+ print(k, target_rewards)
214
+ all_rewards, all_desired_rewards, all_full_prompts, all_full_responses = evaluate_model(model, reward_models, tokenizer, target_rewards, instructions, gpu_id)
215
+ if process_id == 0:
216
+ evaluation_result = {
217
+ 'prompt': all_full_prompts,
218
+ 'response': all_full_responses,
219
+ }
220
+ for i in range(num_rewards):
221
+ evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
222
+ evaluation_result['desired_score{}'.format(i+1)] = all_desired_rewards[i]
223
+
224
+ print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
225
+ print('total average desired score {}: {}'.format(i+1, np.mean(evaluation_result['desired_score{}'.format(i+1)])))
226
+
227
+ dataframe = pd.DataFrame(evaluation_result)
228
+ if len(preference) == 3:
229
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}_{}.csv'.format(preference[0], preference[1], preference[2])))
230
+ else:
231
+ dataframe.to_csv(os.path.join(script_args.save_directory, script_args.wandb_name,'eval_data_pref{}_{}.csv'.format(preference[0], preference[1])))
232
+
233
+
234
+
ric/generation.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ from accelerate import Accelerator
3
+ import torch
4
+ from datasets import load_from_disk, disable_caching
5
+ from tqdm import tqdm
6
+ from transformers import AutoModelForCausalLM, DataCollatorWithPadding
7
+ from peft import PeftModel
8
+ from torch.utils.data import DataLoader
9
+ import numpy as np
10
+ import pandas as pd
11
+ from utils import get_clean_data, load_main_tokenizer, \
12
+ reset_score_in_dataset, Instructions_n, clean_gpu_memory, Instructions_summary_n
13
+ from multi_reward_models import RewardModels
14
+ tqdm.pandas()
15
+ disable_caching()
16
+ from trl import set_seed
17
+
18
+ def generate_data(
19
+ model_path,
20
+ reward_model_path_list,
21
+ dataset,
22
+ tokenizer_name,
23
+ rm_tokenizer_path_list,
24
+ save_path,
25
+ reward_stats_path,
26
+ iter=0,
27
+ peft_name=None,
28
+ args=None,
29
+ exp_type='assistant',
30
+ ):
31
+ set_seed(8888 + iter)
32
+ print('Generating ...')
33
+ process_id = Accelerator().local_process_index
34
+ gpu_id = process_id
35
+ print('process: {}, model gpu id: {}'.format(process_id, gpu_id))
36
+ tokenizer = load_main_tokenizer(tokenizer_name)
37
+ ## load model
38
+ model = AutoModelForCausalLM.from_pretrained(
39
+ model_path,
40
+ torch_dtype=torch.bfloat16, # fast inference
41
+ device_map=gpu_id,
42
+ )
43
+ model.resize_token_embeddings(len(tokenizer))
44
+ if peft_name is not None:
45
+ model = PeftModel.from_pretrained(model, peft_name)
46
+ if hasattr(model, 'merge_and_unload'):
47
+ model = model.merge_and_unload()
48
+
49
+ generation_kwargs = {
50
+ "max_new_tokens": 128 if exp_type == 'assistant' else 48,
51
+ "min_length": -1,
52
+ "top_k": 0.0,
53
+ "top_p": 0.9,
54
+ 'temperature': 0.7,
55
+ "do_sample": True,
56
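+ # suppress EOS at the first generation step so the model cannot return an empty response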
+ "begin_suppress_tokens": [tokenizer.eos_token_id],
57
+ }
58
+ ### for evaluation
59
+ print('generation........')
60
+ tokenizer.padding_side = "left"
61
+
62
+ if type(dataset) == str:
63
+ dataset = load_from_disk(dataset)
64
+ select_index = np.random.randint(0, len(dataset), args.num_generation_samples)
65
+ selected_dataset = dataset.select(select_index)
66
+ scores_name_list = []
67
+ for key in dataset.column_names:
68
+ if key.startswith('score'):
69
+ scores_name_list.append(key)
70
+
71
+ selected_dataset = reset_score_in_dataset(selected_dataset, tokenizer, exp_type=exp_type) #, [dataset[key] for key in scores_name_list])
72
+ print(f"Size of the selected dataset: {len(selected_dataset)}")
73
+
74
+ batch_size = 1 # batch size > 1 will pad on the left and will negatively influence generation
75
+ remove_columns = []
76
+ for name in ['input_ids', 'prompt', 'text', 'response', 'query', 'prompt_with_score'] + scores_name_list:
77
+ if name in selected_dataset.column_names:
78
+ remove_columns.append(name)
79
+ selected_dataset = selected_dataset.remove_columns(remove_columns)
80
+ selected_dataset = selected_dataset.rename_column('prompt_with_score_ids', 'input_ids')
81
+
82
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
83
+ data_loader = DataLoader(selected_dataset, batch_size=batch_size, drop_last=True, collate_fn=data_collator)
84
+
85
+ accelerator = Accelerator()
86
+ model, data_loader = accelerator.prepare(model, data_loader)
87
+
88
+ full_response_tensors = []
89
+ full_prompts = []
90
+
91
+ pbar = tqdm(total=len(selected_dataset) // batch_size // accelerator.num_processes)
92
+ with torch.no_grad():
93
+ for i, batch in enumerate(data_loader):
94
+ response_tensors = accelerator.unwrap_model(model).generate(batch['input_ids'], **generation_kwargs)
95
+ full_response_tensors.extend(response_tensors)
96
+ full_prompts.extend(batch['input_ids'])
97
+ pbar.update(1)
98
+
99
+ ## decoding
100
+ full_prompts = tokenizer.batch_decode(full_prompts)
101
+ full_responses = tokenizer.batch_decode(full_response_tensors)
102
+ full_prompts, full_responses = get_clean_data(full_responses, full_prompts, remove_bad=True)
103
+
104
+ ## del model and clear gpu
105
+ del model, data_loader
106
+ clean_gpu_memory()
107
+
108
+ # Compute score
109
+ reward_models = RewardModels(reward_model_path_list, rm_tokenizer_path_list, gpu_id, reward_stats_path)
110
+ if exp_type == 'assistant':
111
+ instructions = Instructions_n(reward_models.num_rewards)
112
+ else:
113
+ instructions = Instructions_summary_n(reward_models.num_rewards)
114
+
115
+ queries_responses = [(instructions.get_input(text), instructions.get_response(text)) for text in full_responses]
116
+
117
+ if hasattr(instructions, 'get_post'):
118
+ rewards_list = reward_models.get_reward_model_scores(queries_responses, instructions.get_post)
119
+ else:
120
+ rewards_list = reward_models.get_reward_model_scores(queries_responses)
121
+
122
+ desired_rewards_list = [[] for _ in range(reward_models.num_rewards)]
123
+ for text in full_responses:
124
+ desired_scores = instructions.get_scores(text)
125
+ for i in range(reward_models.num_rewards):
126
+ desired_rewards_list[i].append(float(desired_scores[i]))
127
+
128
+ ### merge data
129
+ ### errors here may be due to an old version of transformers/accelerate/peft
130
+ all_rewards = []
131
+ all_desired_rewards = []
132
+ for i in range(len(rewards_list)):
133
+ all_rewards.append(accelerator.gather_for_metrics(rewards_list[i]))
134
+ all_desired_rewards.append(accelerator.gather_for_metrics(desired_rewards_list[i]))
135
+ all_full_prompts = accelerator.gather_for_metrics(full_prompts)
136
+ all_full_responses = accelerator.gather_for_metrics(full_responses)
137
+
138
+ if process_id == 0:
139
+ evaluation_result = {
140
+ 'prompt': all_full_prompts,
141
+ 'response': all_full_responses,
142
+ }
143
+ for i in range(len(rewards_list)):
144
+ evaluation_result['obtained_score{}'.format(i+1)] = all_rewards[i]
145
+ evaluation_result['desired_score{}'.format(i+1)] = all_desired_rewards[i]
146
+
147
+ print('total average obtained score {}: {}'.format(i+1, np.mean(evaluation_result['obtained_score{}'.format(i+1)])))
148
+ print('total average desired score {}: {}'.format(i+1, np.mean(evaluation_result['desired_score{}'.format(i+1)])))
149
+
150
+ dataframe = pd.DataFrame(evaluation_result)
151
+ dataframe.to_csv(os.path.join(save_path,'data.csv'))
152
+
153
+ # wait for the main process
154
+ accelerator.wait_for_everyone()
155
+
156
+
ric/main.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ import copy
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ from accelerate import Accelerator
6
+ import torch
7
+ from utils import clean_gpu_memory, merge_dataset, save_configs
8
+ from training import train_model
9
+ from generation import generate_data
10
+ from transformers import HfArgumentParser
11
+
12
+ # define paths for two datasets
13
+ hhrlhf_dataset_path = 'Anthropic/hh-rlhf'
14
+ summary_dataset_path = 'openai/summarize_from_feedback'
15
+
16
+ @dataclass
17
+ class ScriptArguments:
18
+ log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
19
+ disable_wandb: Optional[bool] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
20
+ save_directory: Optional[str] = field(default='./logs_trl/', metadata={'help':'path'})
21
+ learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate for online training"})
22
+ batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
23
+ training_steps: Optional[int] = field(default=20000, metadata={'help': 'number of training steps in the offline training'})
24
+ online_training_steps: Optional[int] = field(default=4000, metadata={'help': 'number of training steps in the online training'})
25
+ gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
26
+ num_online_iterations: Optional[int] = field(default=1, metadata={'help': 'number of the online generation and training'})
27
+ num_generation_samples: Optional[int] = field(default=20000, metadata={'help': 'number of samples generated'})
28
+ max_grad_norm: Optional[float] = field(default=1, metadata={"help": "Maximum gradient norm for gradient clipping"})
29
+ quantile_threshold: Optional[float] = field(default=0.7)
30
+ num_origin_samples: Optional[int] = field(default=10000)
31
+ load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "loading model in 8 bit or bfloat16"})
32
+ bf16: Optional[bool] = field(default=False,
33
+ metadata={"help": "if True, training with bfloat16 (not supported by V100, but for A100, A40, A6000), otherwise we use fp32"}
34
+ )
35
+ wandb_name: Optional[str] = field(default='ric_assistant_harmlesshelpful_offline20000_lr1e-4', metadata={"help": "Name for this experiment"})
36
+ base_model_name: Optional[str] = field(default='meta-llama/Llama-2-7b-hf', metadata={"help": "local path to the base model or the huggingface id"})
37
+ peft_name: Optional[str] = field(default='', metadata={"help": "local path to the peft model"})
38
+ reward_names:Optional[str] = field(default='harmless,helpful')
39
+ train_dataset_path: Optional[str] = field(default='./datasets/all_full_train_harmhelp.hf')
40
+ train_reward_stats_path: Optional[str] = field(default='')
41
+ exp_type: Optional[str] = field(default='assistant', metadata={"help": "exp type, 'summary' or 'assistant' "})
42
+
43
+
44
+ parser = HfArgumentParser(ScriptArguments)
45
+ script_args = parser.parse_args_into_dataclasses()[0]
46
+ exp_type = script_args.exp_type
47
+ base_model_name = script_args.base_model_name
48
+ tokenizer_name = script_args.base_model_name
49
+ print('base model: ', base_model_name)
50
+
51
+ if script_args.disable_wandb: # if you don't need the wandb log
52
+ os.environ['WANDB_DISABLED'] = 'true'
53
+
54
+ reward_names = [x.strip() for x in script_args.reward_names.split(',')]
55
+ reward_path_tokenizer_dict = {
56
+ 'harmless': ['Ray2333/gpt2-large-harmless-reward_model'],
57
+ 'helpful': ['Ray2333/gpt2-large-helpful-reward_model'],
58
+ 'deberta': ['OpenAssistant/reward-model-deberta-v3-large-v2'],
59
+ 'summary': ['Tristan/gpt2_reward_summarization'],
60
+ 'faithful':['CogComp/bart-faithful-summary-detector'],
61
+ 'humor': ['mohameddhiab/humor-no-humor'],
62
+ }
63
+ reward_model_path_list = []
64
+ rm_tokenizer_path_list = []
65
+ for name in reward_names:
66
+ if name not in reward_path_tokenizer_dict.keys():
67
+ raise NotImplementedError
68
+ reward_model_path_list.append(reward_path_tokenizer_dict[name][0])
69
+ rm_tokenizer_path_list.append(reward_path_tokenizer_dict[name][0])
70
+
71
+ # Remember to generate these datasets before training
72
+ train_dataset_path = script_args.train_dataset_path
73
+ reward_stats_path = script_args.train_reward_stats_path if len(script_args.train_reward_stats_path) else script_args.train_dataset_path + '/all_reward_stat.npy'
74
+
75
+
76
+ save_info = {
77
+ 'train_dataset_path': train_dataset_path,
78
+ 'base_model_name': base_model_name,
79
+ 'tokenizer_name': tokenizer_name
80
+ }
81
+ for i in range(len(reward_model_path_list)):
82
+ save_info['reward_peft_path{}'.format(i+1)] = reward_model_path_list[i]
83
+ save_configs(save_info, os.path.join(script_args.save_directory, script_args.wandb_name))
84
+
85
+ save_path = os.path.join(script_args.save_directory, script_args.wandb_name)
86
+ os.makedirs(save_path, exist_ok=True)
87
+
88
+ ## offline training
89
+ dataset = train_model(
90
+ base_model_name=base_model_name,
91
+ reward_model_path_list=reward_model_path_list,
92
+ train_dataset=train_dataset_path,
93
+ save_path=save_path + '/model_iter0',
94
+ tokenizer_name=tokenizer_name,
95
+ rm_tokenizer_path_list=rm_tokenizer_path_list,
96
+ training_steps=script_args.training_steps,
97
+ learning_rate=1.414e-4, # use larger lr for offline training than online training (script_args.learning_rate)
98
+ args=script_args,
99
+ exp_type=exp_type,
100
+ )
101
+ clean_gpu_memory()
102
+
103
+ online_dataset = None
104
+ model_path = base_model_name
105
+ for i in range(script_args.num_online_iterations):
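+ # each online iteration: generate samples with the current model, score and merge them with the original data, then fine-tune again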
106
+ print('iteration {} ...'.format(i))
107
+ checkpoint_path = os.path.join(save_path, 'model_iter{}'.format(i))
108
+ if i == 0 and script_args.training_steps == 0:
109
+ ## skip the offline training from a saved lora parameter
110
+ peft_name = script_args.peft_name
111
+ else:
112
+ peft_name = checkpoint_path
113
+
114
+ # ### generation
115
+ if script_args.num_generation_samples > 0 and not os.path.exists(os.path.join(checkpoint_path, 'data.csv')):
116
+ generate_data(
117
+ model_path,
118
+ reward_model_path_list=reward_model_path_list,
119
+ tokenizer_name=tokenizer_name,
120
+ rm_tokenizer_path_list=rm_tokenizer_path_list,
121
+ dataset=dataset,
122
+ save_path=os.path.join(save_path, 'model_iter{}'.format(i)),
123
+ peft_name=peft_name,
124
+ reward_stats_path=reward_stats_path,
125
+ iter=i,
126
+ args=script_args,
127
+ exp_type=exp_type,
128
+ )
129
+
130
+ clean_gpu_memory()
131
+ ### merging original data and generated data
132
+ info_path = os.path.join(script_args.save_directory, script_args.wandb_name,'reward_info.npy')
133
+ merged_data, online_dataset = merge_dataset(dataset, online_dataset, checkpoint_path, tokenizer_name,
134
+ info_path=info_path, exp_type=exp_type, quantile_threshold=script_args.quantile_threshold,
135
+ sample_origin=script_args.num_origin_samples)
136
+
137
+ train_model(
138
+ base_model_name=model_path,
139
+ peft_name=peft_name,
140
+ reward_model_path_list=reward_model_path_list,
141
+ train_dataset=merged_data,
142
+ save_path=save_path + '/model_iter{}'.format(i+1),
143
+ tokenizer_name=tokenizer_name,
144
+ rm_tokenizer_path_list=rm_tokenizer_path_list,
145
+ training_steps=script_args.online_training_steps,
146
+ learning_rate=script_args.learning_rate,
147
+ lr_scheduler_type='constant',
148
+ iter=i+1,
149
+ args=script_args,
150
+ exp_type=exp_type,
151
+ )
152
+ clean_gpu_memory()
153
+
154
+
155
+
ric/multi_reward_models.py ADDED
@@ -0,0 +1,64 @@
1
+ from transformers import AutoTokenizer
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from utils import load_reward_model, get_rewards
6
+
7
+ class RewardModels():
8
+ def __init__(self, reward_model_path_list, rm_tokenizer_path_list, gpu_id_list, reward_stats_path=None):
9
+ assert len(reward_model_path_list) == len(rm_tokenizer_path_list)
10
+ self.reward_model_path_list = reward_model_path_list
11
+ self.rm_tokenizer_path_list = rm_tokenizer_path_list
12
+ self.num_rewards = len(reward_model_path_list)
13
+ self.reward_stats = np.load(reward_stats_path) if reward_stats_path is not None else None
14
+ self.reward_models = []
15
+ self.rm_tokenizers = []
16
+ if type(gpu_id_list) != list:
17
+ gpu_id_list = [gpu_id_list] * self.num_rewards
18
+
19
+ print('Loading reward models .....')
20
+ for i in range(self.num_rewards):
21
+ self.reward_models.append(load_reward_model(self.reward_model_path_list[i], gpu_id_list[i]))
22
+ self.rm_tokenizers.append(AutoTokenizer.from_pretrained(self.rm_tokenizer_path_list[i]))
23
+
24
+
25
+ def get_reward_model_scores(self, queries_responses, summary_fun=None):
26
+ texts_for_rewards = []
27
+ for i in range(self.num_rewards):
28
+ if i >= 1 and self.rm_tokenizer_path_list[i] == self.rm_tokenizer_path_list[i-1]:
29
+ texts_for_rewards.append(texts_for_rewards[-1])
30
+ elif 'faithful' in self.reward_model_path_list[i]:
31
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
32
+ temp_encoded_texts = [self.rm_tokenizers[i](text=r, text_pair=summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
33
+ texts_for_rewards.append(temp_encoded_texts)
34
+ elif 'summary' in self.reward_model_path_list[i] or 'summarization' in self.reward_model_path_list[i]: # reverse prompt and response
35
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
36
+ temp_encoded_texts = [self.rm_tokenizers[i](r + " " + self.rm_tokenizers[i].bos_token + " " + summary_fun(q), return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
37
+ texts_for_rewards.append(temp_encoded_texts)
38
+ elif 'humor' in self.reward_model_path_list[i]: # use only response
39
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
40
+ temp_encoded_texts = [self.rm_tokenizers[i](r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
41
+ texts_for_rewards.append(temp_encoded_texts)
42
+ else:
43
+ max_length = min(self.rm_tokenizers[i].model_max_length, 1024)
44
+ temp_encoded_texts = [self.rm_tokenizers[i](q, r, return_tensors='pt', truncation=True, max_length=max_length) for q, r in queries_responses]
45
+ texts_for_rewards.append(temp_encoded_texts)
46
+
47
+ # normalize reward
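+ # reward_stats, when provided, is assumed to hold per-reward (mean, std) pairs, either flattened as [m1, s1, m2, s2, ...] or as one row per reward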
48
+ rewards = []
49
+ for i in range(self.num_rewards):
50
+ if self.reward_stats is not None:
51
+ if type(self.reward_stats) == list or len(self.reward_stats) == 2 * self.num_rewards:
52
+ reward_mean_std = (self.reward_stats[2*i], self.reward_stats[2*i+1])
53
+ else:
54
+ reward_mean_std = self.reward_stats[i]
55
+ else:
56
+ reward_mean_std = None
57
+
58
+ if 'humor' in self.reward_model_path_list[i] or 'faithful' in self.reward_model_path_list[i]:
59
+ temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std, sub_position=1)
60
+ else:
61
+ temp_reward = get_rewards(self.reward_models[i], texts_for_rewards[i], reward_mean_std=reward_mean_std)
62
+ rewards.append(temp_reward)
63
+ return rewards
64
+