tokestermw commited on
Commit
98c25e1
1 Parent(s): 68601ec

Add ORPO example and e2e test (#1572)

Browse files

* add example for mistral orpo

* sample_packing: false for orpo

* go to load_dataset (since load_rl_datasets requires a transform_fn, which only dpo uses currently)

.gitignore CHANGED
@@ -133,6 +133,7 @@ venv/
133
  ENV/
134
  env.bak/
135
  venv.bak/
 
136
 
137
  # Spyder project settings
138
  .spyderproject
 
133
  ENV/
134
  env.bak/
135
  venv.bak/
136
+ venv3.10/
137
 
138
  # Spyder project settings
139
  .spyderproject
docs/rlhf.qmd CHANGED
@@ -49,7 +49,7 @@ remove_unused_columns: false
49
  chat_template: chatml
50
  datasets:
51
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
52
- type: orpo.chat_template
53
  ```
54
 
55
  #### Using local dataset files
 
49
  chat_template: chatml
50
  datasets:
51
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
52
+ type: chat_template.argilla
53
  ```
54
 
55
  #### Using local dataset files
examples/mistral/mistral-qlora-orpo.yml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# QLoRA + ORPO fine-tuning example for Mistral-7B.
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

# 4-bit quantized base weights (QLoRA).
load_in_8bit: false
load_in_4bit: true
strict: false

# ORPO preference-optimization settings.
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
# NOTE: sample packing is not supported with ORPO.
sample_packing: false
pad_to_sequence_len: true

# LoRA adapter hyperparameters.
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tests/e2e/test_dpo.py CHANGED
@@ -158,3 +158,50 @@ class TestDPOLlamaLora(unittest.TestCase):
158
 
159
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
160
  assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
160
  assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
161
+
162
+ @with_temp_dir
163
+ def test_orpo_lora(self, temp_dir):
164
+ # pylint: disable=duplicate-code
165
+ cfg = DictDefault(
166
+ {
167
+ "base_model": "JackFram/llama-68m",
168
+ "tokenizer_type": "LlamaTokenizer",
169
+ "sequence_len": 1024,
170
+ "load_in_8bit": True,
171
+ "adapter": "lora",
172
+ "lora_r": 64,
173
+ "lora_alpha": 32,
174
+ "lora_dropout": 0.1,
175
+ "lora_target_linear": True,
176
+ "special_tokens": {},
177
+ "rl": "orpo",
178
+ "orpo_alpha": 0.1,
179
+ "remove_unused_columns": False,
180
+ "chat_template": "chatml",
181
+ "datasets": [
182
+ {
183
+ "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
184
+ "type": "chat_template.argilla",
185
+ "split": "train",
186
+ },
187
+ ],
188
+ "num_epochs": 1,
189
+ "micro_batch_size": 4,
190
+ "gradient_accumulation_steps": 1,
191
+ "output_dir": temp_dir,
192
+ "learning_rate": 0.00001,
193
+ "optimizer": "paged_adamw_8bit",
194
+ "lr_scheduler": "cosine",
195
+ "max_steps": 20,
196
+ "save_steps": 10,
197
+ "warmup_steps": 5,
198
+ "gradient_checkpointing": True,
199
+ "gradient_checkpointing_kwargs": {"use_reentrant": True},
200
+ }
201
+ )
202
+ normalize_config(cfg)
203
+ cli_args = TrainerCliArgs()
204
+ dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
205
+
206
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
207
+ assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()