anakin87 committed
Commit 4db7146
1 Parent(s): 5cbf999
README.md CHANGED
@@ -1,5 +1,8 @@
---
license: other
+ license_name: gemma-terms-of-use
+ license_link: https://ai.google.dev/gemma/terms
+ library_name: transformers
base_model: google/gemma-2b
tags:
- trl
@@ -16,10 +19,37 @@ language:

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
-
+ ![image/png](./assets/gemma-2b-orpo.png)
# gemma-2b-orpo

- This model is a fine-tuned version of [google/gemma-2b](https://huggingface.co/google/gemma-2b) on an unknown dataset..
+ This is an ORPO fine-tune of [google/gemma-2b](https://huggingface.co/google/gemma-2b) with
+ [`alvarobartt/dpo-mix-7k-simplified`](https://huggingface.co/datasets/alvarobartt/dpo-mix-7k-simplified).
+
+ ## ORPO
+ [ORPO (Odds Ratio Preference Optimization)](https://arxiv.org/abs/2403.07691) is a new training paradigm that combines the usually separate phases
+ of SFT (Supervised Fine-Tuning) and Preference Alignment (typically performed with RLHF or simpler methods like DPO) into a single step. It offers:
+ - Faster training
+ - Less memory usage (no reference model needed)
+ - Good results!
+
+ ## 🏆 Evaluation
+
+ ### Nous
+
+ gemma-2b-orpo performs well on Nous' benchmark suite (evaluation performed using [LLM AutoEval](https://github.com/mlabonne/llm-autoeval)).
+
+ | Model | Average | AGIEval | GPT4All | TruthfulQA | Bigbench |
+ |---|---:|---:|---:|---:|---:|
+ | [anakin87/gemma-2b-orpo](https://huggingface.co/anakin87/gemma-2b-orpo) [📄](./assets/gemma-2b-orpo-Nous.md) | 39.45 | 23.76 | 58.25 | 44.47 | 31.32 |
+ | [mlabonne/Gemmalpaca-2B](https://huggingface.co/mlabonne/Gemmalpaca-2B) [📄](https://gist.github.com/mlabonne/4b638752fc3227df566f9562064cb864) | 38.39 | 24.48 | 51.22 | 47.02 | 30.85 |
+ | [google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it) [📄](https://gist.github.com/mlabonne/db0761e74175573292acf497da9e5d95) | 36.1 | 23.76 | 43.6 | 47.64 | 29.41 |
+ | [google/gemma-2b](https://huggingface.co/google/gemma-2b) [📄](https://gist.github.com/mlabonne/7df1f238c515a5f63a750c8792cef59e) | 34.26 | 22.7 | 43.35 | 39.96 | 31.03 |
+
+
+ ## 🙏 Dataset
+ [`alvarobartt/dpo-mix-7k-simplified`](https://huggingface.co/datasets/alvarobartt/dpo-mix-7k-simplified)
+ is a simplified version of [`argilla/dpo-mix-7k`](https://huggingface.co/datasets/argilla/dpo-mix-7k).
+ You can find more information [here](https://huggingface.co/alvarobartt/Mistral-7B-v0.1-ORPO#about-the-dataset).

## Model description

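For readers of the card, a minimal inference sketch (illustrative only, not part of this commit; it assumes the pushed `anakin87/gemma-2b-orpo` checkpoint ships the ChatML chat template used during training):

```python
# Hypothetical quick-start: load the fine-tuned checkpoint and chat with it.
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="anakin87/gemma-2b-orpo",  # the model pushed by this commit
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Render the conversation with the tokenizer's chat template (assumption:
# the ChatML template was saved with the tokenizer, as in the training notebook).
messages = [{"role": "user", "content": "Briefly explain what ORPO is."}]
prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
print(outputs[0]["generated_text"])
```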
assets/gemma-2b-orpo-Nous.md ADDED
@@ -0,0 +1,80 @@
+ | Model |AGIEval|GPT4All|TruthfulQA|Bigbench|Average|
+ |--------------------------------------------------------------|------:|------:|---------:|-------:|------:|
+ |[gemma-2b-orpo](https://huggingface.co/anakin87/gemma-2b-orpo)| 23.76| 58.25| 44.47| 31.32| 39.45|
+
+ ### AGIEval
+ | Task |Version| Metric |Value| |Stderr|
+ |------------------------------|------:|--------|----:|---|-----:|
+ |agieval_aqua_rat | 0|acc |15.35|± | 2.27|
+ | | |acc_norm|17.32|± | 2.38|
+ |agieval_logiqa_en | 0|acc |25.96|± | 1.72|
+ | | |acc_norm|29.34|± | 1.79|
+ |agieval_lsat_ar | 0|acc |19.57|± | 2.62|
+ | | |acc_norm|20.00|± | 2.64|
+ |agieval_lsat_lr | 0|acc |23.14|± | 1.87|
+ | | |acc_norm|21.96|± | 1.83|
+ |agieval_lsat_rc | 0|acc |24.16|± | 2.61|
+ | | |acc_norm|24.54|± | 2.63|
+ |agieval_sat_en | 0|acc |29.61|± | 3.19|
+ | | |acc_norm|27.18|± | 3.11|
+ |agieval_sat_en_without_passage| 0|acc |30.58|± | 3.22|
+ | | |acc_norm|24.76|± | 3.01|
+ |agieval_sat_math | 0|acc |23.64|± | 2.87|
+ | | |acc_norm|25.00|± | 2.93|
+
+ Average: 23.76%
+
+ ### GPT4All
+ | Task |Version| Metric |Value| |Stderr|
+ |-------------|------:|--------|----:|---|-----:|
+ |arc_challenge| 0|acc |37.97|± | 1.42|
+ | | |acc_norm|40.61|± | 1.44|
+ |arc_easy | 0|acc |67.63|± | 0.96|
+ | | |acc_norm|65.82|± | 0.97|
+ |boolq | 1|acc |69.85|± | 0.80|
+ |hellaswag | 0|acc |52.39|± | 0.50|
+ | | |acc_norm|67.70|± | 0.47|
+ |openbookqa | 0|acc |25.40|± | 1.95|
+ | | |acc_norm|37.40|± | 2.17|
+ |piqa | 0|acc |71.71|± | 1.05|
+ | | |acc_norm|72.74|± | 1.04|
+ |winogrande | 0|acc |53.59|± | 1.40|
+
+ Average: 58.25%
+
+ ### TruthfulQA
+ | Task |Version|Metric|Value| |Stderr|
+ |-------------|------:|------|----:|---|-----:|
+ |truthfulqa_mc| 1|mc1 |28.76|± | 1.58|
+ | | |mc2 |44.47|± | 1.61|
+
+ Average: 44.47%
+
+ ### Bigbench
+ | Task |Version| Metric |Value| |Stderr|
+ |------------------------------------------------|------:|---------------------|----:|---|-----:|
+ |bigbench_causal_judgement | 0|multiple_choice_grade|51.58|± | 3.64|
+ |bigbench_date_understanding | 0|multiple_choice_grade|43.63|± | 2.59|
+ |bigbench_disambiguation_qa | 0|multiple_choice_grade|37.21|± | 3.02|
+ |bigbench_geometric_shapes | 0|multiple_choice_grade|10.03|± | 1.59|
+ | | |exact_str_match | 0.00|± | 0.00|
+ |bigbench_logical_deduction_five_objects | 0|multiple_choice_grade|23.80|± | 1.91|
+ |bigbench_logical_deduction_seven_objects | 0|multiple_choice_grade|18.00|± | 1.45|
+ |bigbench_logical_deduction_three_objects | 0|multiple_choice_grade|38.67|± | 2.82|
+ |bigbench_movie_recommendation | 0|multiple_choice_grade|22.60|± | 1.87|
+ |bigbench_navigate | 0|multiple_choice_grade|50.00|± | 1.58|
+ |bigbench_reasoning_about_colored_objects | 0|multiple_choice_grade|32.80|± | 1.05|
+ |bigbench_ruin_names | 0|multiple_choice_grade|25.67|± | 2.07|
+ |bigbench_salient_translation_error_detection | 0|multiple_choice_grade|19.24|± | 1.25|
+ |bigbench_snarks | 0|multiple_choice_grade|44.75|± | 3.71|
+ |bigbench_sports_understanding | 0|multiple_choice_grade|49.70|± | 1.59|
+ |bigbench_temporal_sequences | 0|multiple_choice_grade|24.60|± | 1.36|
+ |bigbench_tracking_shuffled_objects_five_objects | 0|multiple_choice_grade|19.20|± | 1.11|
+ |bigbench_tracking_shuffled_objects_seven_objects| 0|multiple_choice_grade|13.60|± | 0.82|
+ |bigbench_tracking_shuffled_objects_three_objects| 0|multiple_choice_grade|38.67|± | 2.82|
+
+ Average: 31.32%
+
+ Average score: 39.45%
+
+ Elapsed time: 02:46:40
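A note on how the tables above reduce to the summary row: each category score is a plain mean of the per-task values. The sketch below reproduces them under the assumption (consistent with these tables to within ±0.01 rounding) that `acc_norm` is preferred where reported, plain `acc` otherwise, and TruthfulQA uses `mc2`:

```python
# Recompute the per-category averages from the tables above.
agieval = [17.32, 29.34, 20.00, 21.96, 24.54, 27.18, 24.76, 25.00]  # acc_norm
gpt4all = [40.61, 65.82, 69.85, 67.70, 37.40, 72.74, 53.59]  # acc_norm or acc
truthfulqa = [44.47]  # mc2
bigbench = [51.58, 43.63, 37.21, 10.03, 23.80, 18.00, 38.67, 22.60, 50.00,
            32.80, 25.67, 19.24, 44.75, 49.70, 24.60, 19.20, 13.60, 38.67]

averages = {name: sum(vals) / len(vals)
            for name, vals in [("AGIEval", agieval), ("GPT4All", gpt4all),
                               ("TruthfulQA", truthfulqa), ("Bigbench", bigbench)]}
for name, avg in averages.items():
    print(f"{name}: {avg:.2f}")  # 23.76, 58.24*, 44.47, 31.32
print(f"Average score: {sum(averages.values()) / len(averages):.2f}")  # 39.45

# *GPT4All comes out 58.24 here because the table shows rounded inputs;
# the harness averages the unrounded values, giving 58.25.
```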
notebooks/training.ipynb ADDED
@@ -0,0 +1,250 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "ef297e12",
+ "metadata": {},
+ "source": [
+ "# gemma-2b-orpo Training notebook\n",
+ "\n",
+ "gemma-2b-orpo is an ORPO fine-tune of [google/gemma-2b](https://huggingface.co/google/gemma-2b) with\n",
+ "[`alvarobartt/dpo-mix-7k-simplified`](https://huggingface.co/datasets/alvarobartt/dpo-mix-7k-simplified).\n",
+ "\n",
+ "Some good resources:\n",
+ "- [HF Transformers Trainer docs](https://huggingface.co/docs/transformers/main_classes/trainer)\n",
+ "- [Docs on training with ORPO using HF TRL](https://huggingface.co/docs/trl/main/en/orpo_trainer)\n",
+ "- [TRL example script for ORPO](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py)\n",
+ "- [How to fine-tune Google Gemma with ChatML and Hugging Face TRL](https://www.philschmid.de/fine-tune-google-gemma)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a34c07e3-396b-4a9f-83c1-ba9ae128c83a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "! pip install git+https://github.com/huggingface/trl.git # install TRL from the main branch to use the ORPOTrainer\n",
+ "! pip install bitsandbytes accelerate\n",
+ "! pip install ninja packaging\n",
+ "! MAX_JOBS=6 pip install flash-attn --no-build-isolation --upgrade # flash-attn speeds up the training on compatible GPUs\n",
+ "! pip install wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "594f15fa-38a3-4898-88f8-65eb4cf22531",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Login to the Hugging Face Hub to save the model\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(token=\"YOUR_TOKEN\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "d155473a-b979-46fd-b858-04907f046a5e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# https://huggingface.co/docs/trl/main/en/orpo_trainer#trl.ORPOConfig\n",
+ "# https://www.philschmid.de/fine-tune-google-gemma\n",
+ "\n",
+ "from trl import ORPOConfig, ORPOTrainer\n",
+ "\n",
+ "# in the following config, we combine the usual HF Trainer args with the ORPOConfig args (beta)\n",
+ "\n",
+ "cfg = ORPOConfig(\n",
+ "    output_dir='content/gemma-2b-orpo',  # usual HF Trainer args: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.args\n",
+ "    num_train_epochs=3,  # number of training epochs\n",
+ "    per_device_train_batch_size=2,  # batch size per device during training\n",
+ "    gradient_accumulation_steps=2,  # number of steps before performing a backward/update pass\n",
+ "    gradient_checkpointing=True,  # use gradient checkpointing to save memory\n",
+ "    optim=\"adamw_torch_fused\",  # use fused adamw optimizer\n",
+ "    logging_steps=20,  # log every 20 steps\n",
+ "    bf16=True,  # use bfloat16 precision\n",
+ "    tf32=True,  # use tf32\n",
+ "    learning_rate=5e-5,  # learning rate\n",
+ "    warmup_ratio=0.1,\n",
+ "    warmup_steps=100,  # overrides warmup_ratio when set\n",
+ "    lr_scheduler_type=\"cosine\",\n",
+ "    max_prompt_length=512,  # ORPO: max tokens for the prompt part\n",
+ "    remove_unused_columns=False,  # keep the prompt/chosen/rejected columns\n",
+ "    max_length=1024,  # ORPO: max tokens for prompt + completion\n",
+ "    beta=0.1,  # ORPO beta\n",
+ "    save_total_limit=3,  # args related to saving the model...\n",
+ "    save_strategy=\"epoch\",\n",
+ "    push_to_hub=True,\n",
+ "    report_to=['wandb'],  # report metrics to Weights & Biases\n",
+ "    hub_model_id='anakin87/gemma-2b-orpo',\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "7f3fcb58-3cb5-4898-b4df-8023f85e9b1e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a8c20d77-38ee-4a29-9baf-ba75fd6c4c72",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_id = \"google/gemma-2b\"\n",
+ "tokenizer_id = \"philschmid/gemma-tokenizer-chatml\"\n",
+ "\n",
+ "\n",
+ "# Load model and tokenizer\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ "    model_id,\n",
+ "    device_map=\"auto\",\n",
+ "    torch_dtype=torch.bfloat16,\n",
+ "    attn_implementation=\"flash_attention_2\",\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)\n",
+ "tokenizer.padding_side = 'right'  # to prevent warnings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "eabf18eb-8159-4712-8a4a-d69a29de794f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "import multiprocessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5915fc40-96cd-4f86-8b7c-d394e801201a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py\n",
+ "\n",
+ "ds = load_dataset(\"alvarobartt/dpo-mix-7k-simplified\")\n",
+ "\n",
+ "def process(row):  # render each message list into a ChatML-formatted string\n",
+ "    row[\"prompt\"] = tokenizer.apply_chat_template(row[\"prompt\"], tokenize=False)\n",
+ "    row[\"chosen\"] = tokenizer.apply_chat_template(row[\"chosen\"], tokenize=False)\n",
+ "    row[\"rejected\"] = tokenizer.apply_chat_template(row[\"rejected\"], tokenize=False)\n",
+ "    return row\n",
+ "\n",
+ "ds = ds.map(\n",
+ "    process,\n",
+ "    num_proc=multiprocessing.cpu_count(),\n",
+ "    load_from_cache_file=False,\n",
+ ")\n",
+ "train_dataset = ds[\"train\"]\n",
+ "eval_dataset = ds[\"test\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "fff5ba61-7f23-4971-8472-cc339a2f803c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'dataset': 'argilla/distilabel-intel-orca-dpo-pairs',\n",
+ " 'prompt': '<bos><|im_start|>user\\nQ:Question: how old julio cesar chavez when he fought de la hoya I found the following answer on Google: He holds records for most successful consecutive defenses of world titles (27), most title fights (37), most title-fight victories (31) and he is after Joe Louis with (23) for most title defenses won by knockout (21). Is that a correct answer? Yes or no.\\nA:<|im_end|>\\n',\n",
+ " 'chosen': \"<|im_start|>assistant\\n Yes, the information you found on Google is correct. Julio César Chávez holds several records related to world title defenses and victories, and he is considered one of the greatest boxers in history. Here is a detailed answer to your question:\\n\\nJulio César Chávez was born on July 12, 1962, in Ciudad Obregón, Sonora, Mexico. He began boxing at a young age and quickly made a name for himself in the sport, winning his first world title in 1984 when he defeated Mario Miranda for the WBC super featherweight championship.\\n\\nOver the course of his career, Chávez would go on to hold titles in three different weight classes (super featherweight, lightweight, and junior welterweight) and defend his titles a record 27 times, including 21 times by knockout. This record for most successful consecutive defenses of world titles has never been broken, and it is a testament to Chávez's skill, determination, and durability as a fighter.\\n\\nIn addition to his record for most successful title defenses, Chávez also holds the records for most title fights (37) and most title-fight victories (31). These records are also unbroken and demonstrate Chávez's consistent success and dominance in the ring.\\n\\nChávez's impressive record and achievements have earned him a place among the greatest boxers of all time, and he is often compared to other boxing legends such as Joe Louis, who holds the record for most title defenses won by knockout (23). While Chávez did not quite match Louis's record for most title defenses won by knockout, he is still widely regarded as one of the greatest boxers in history and a true legend of the sport.\\n\\nIn conclusion, the information you found on Google is correct: Julio César Chávez holds several records related to world title defenses and victories, and he is considered one of the greatest boxers in history. His impressive record and achievements have earned him a place among the greatest boxers of all time, and he will always be remembered as a true legend of the sport.<|im_end|>\\n<eos>\",\n",
+ " 'rejected': '<|im_start|>assistant\\nNo, the answer provided does not directly answer the question regarding the age of Julio Cesar Chavez when he fought Oscar De La Hoya. The provided information describes some general records held by Julio Cesar Chavez throughout his career. To answer your original question, let me provide the relevant information:\\n\\nJulio Cesar Chavez fought Oscar De La Hoya on June 7, 1996, in a match called \"Ultimate Glory.\" Chavez was born on July 12, 1962. To calculate his age at the time of the fight, we need to find the difference between the fight date and his birthdate.\\n\\nFrom July 12, 1962, to June 7, 1996, there are:\\n- 33 years (from 1962 to 1995)\\n- An additional year from his birthday in 1995 (July 12, 1995) to the fight date in 1996 (June 7, 1996), which is approximately 10 months and 26 days.\\n\\nTherefore, Julio Cesar Chavez was about 33 years and 10 months old when he fought Oscar De La Hoya.<|im_end|>\\n<eos>'}"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e3d29888-f3db-46eb-8302-6dff2ebf27f3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "run = wandb.init(project=\"YOUR_PROJECT_NAME\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a92cb46-4d17-4910-8cf0-40970d5f7193",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "orpo_trainer = ORPOTrainer(\n",
+ "    model=model,\n",
+ "    args=cfg,\n",
+ "    train_dataset=train_dataset,\n",
+ "    eval_dataset=eval_dataset,\n",
+ "    tokenizer=tokenizer\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15d24895-204c-4313-943a-d6c31fcde6e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "orpo_trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ade7f4ed-339b-4609-ab41-8aa736001474",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "orpo_trainer.push_to_hub()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
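As background for the `beta=0.1` setting in the config above: the sketch below paraphrases the ORPO objective from the paper ([arXiv:2403.07691](https://arxiv.org/abs/2403.07691)). It is illustrative only, not TRL's actual implementation; the function name and argument shapes are assumptions.

```python
# ORPO objective sketch: SFT loss on the chosen response plus a beta-weighted
# odds-ratio term that prefers the chosen over the rejected response.
import torch
import torch.nn.functional as F

def orpo_loss(chosen_logps, rejected_logps, nll_chosen, beta=0.1):
    # chosen_logps / rejected_logps: per-example, length-normalized sequence
    # log-probabilities (each value is < 0, so exp(.) is a valid probability).
    # log odds = log p - log(1 - p), computed stably with log1p.
    log_odds_chosen = chosen_logps - torch.log1p(-torch.exp(chosen_logps))
    log_odds_rejected = rejected_logps - torch.log1p(-torch.exp(rejected_logps))
    # L_OR = -log sigmoid(log-odds difference); no reference model is involved,
    # which is why ORPO needs less memory than DPO-style training.
    l_or = -F.logsigmoid(log_odds_chosen - log_odds_rejected)
    return nll_chosen + beta * l_or.mean()
```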
notebooks/usage.ipynb ADDED
The diff for this file is too large to render. See raw diff