omasteam committed on
Commit
f4ab725
·
verified ·
1 Parent(s): 2efb897

Upload 36 files

Browse files
Files changed (37) hide show
  1. .gitattributes +2 -0
  2. unsloth_compiled_cache/UnslothAlignPropTrainer.py +678 -0
  3. unsloth_compiled_cache/UnslothBCOTrainer.py +1857 -0
  4. unsloth_compiled_cache/UnslothCPOTrainer.py +1618 -0
  5. unsloth_compiled_cache/UnslothDDPOTrainer.py +914 -0
  6. unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
  7. unsloth_compiled_cache/UnslothGKDTrainer.py +880 -0
  8. unsloth_compiled_cache/UnslothGRPOTrainer.py +0 -0
  9. unsloth_compiled_cache/UnslothIterativeSFTTrainer.py +924 -0
  10. unsloth_compiled_cache/UnslothKTOTrainer.py +0 -0
  11. unsloth_compiled_cache/UnslothNashMDTrainer.py +1019 -0
  12. unsloth_compiled_cache/UnslothORPOTrainer.py +1574 -0
  13. unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1327 -0
  14. unsloth_compiled_cache/UnslothPPOTrainer.py +1319 -0
  15. unsloth_compiled_cache/UnslothPRMTrainer.py +848 -0
  16. unsloth_compiled_cache/UnslothRLOOTrainer.py +1174 -0
  17. unsloth_compiled_cache/UnslothRewardTrainer.py +866 -0
  18. unsloth_compiled_cache/UnslothSFTTrainer.py +1253 -0
  19. unsloth_compiled_cache/UnslothXPOTrainer.py +1062 -0
  20. unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-312.pyc +0 -0
  21. unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  22. unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  23. unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-312.pyc +0 -0
  24. unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +3 -0
  25. unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  26. unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +3 -0
  27. unsloth_compiled_cache/__pycache__/UnslothIterativeSFTTrainer.cpython-312.pyc +0 -0
  28. unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  29. unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  30. unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  31. unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  32. unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  33. unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  34. unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  35. unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  36. unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  37. unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
37
+ unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
unsloth_compiled_cache/UnslothAlignPropTrainer.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, Path, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warnings)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options passed as `options=` to the `torch.compile` decorator on
# `chunked_selective_log_softmax`.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-probability of the token picked by `index` at every position.

    The leading dimensions of `logits` are flattened and processed in four
    row-wise chunks; each chunk is upcast to float32 before the gather and the
    `logsumexp`, so the result is float32.

    Args:
        logits: Score tensor whose last dimension is indexed by `index`. The
            final reshape uses `logits.shape[0]` and `logits.shape[1]`, so a
            3-D input is expected (presumably batch x sequence x vocabulary).
        index: Integer tensor with one entry per row of the flattened `logits`.

    Returns:
        Tensor of shape `(logits.shape[0], logits.shape[1])` holding
        `logits[..., index] - logsumexp(logits, dim=-1)`.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)

    # Exactly four chunks, paired row-for-row between logits and indices.
    chunk_pairs = zip(
        torch.chunk(flat_logits, chunks = 4, dim = 0),
        torch.chunk(flat_index, chunks = 4, dim = 0),
    )

    logps_per_chunk = []
    for logit_rows, index_rows in chunk_pairs:
        logit_rows = logit_rows.to(torch.float32)
        # log_softmax(x)[i] == x[i] - logsumexp(x); only the chosen column is needed.
        chosen = torch.gather(logit_rows, dim = -1, index = index_rows.unsqueeze(-1)).squeeze(-1)
        logps_per_chunk.append(chosen - torch.logsumexp(logit_rows, dim = -1))

    merged = torch.concat(logps_per_chunk)
    return merged.reshape((logits.shape[0], logits.shape[1]))
51
@dataclass
class UnslothAlignPropConfig(AlignPropConfig):
    """
    Configuration for the [`AlignPropTrainer`], extended with two Unsloth fields.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The defaults listed below are the ones in this class's `__init__` signature. Any other
    `AlignPropConfig` option (for example `tracker_kwargs`, `accelerator_kwargs`, `project_kwargs` or
    `truncated_rand_backprop_minmax`) is accepted through `**kwargs`, forwarded unchanged to
    `AlignPropConfig.__init__`, and keeps whatever default the parent class defines.

    Parameters:
        exp_name (`str`, *optional*, defaults to `"colab_kernel_launcher"`):
            Name of this experiment.
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `3407`):
            Random seed for reproducibility.
        log_with (`str` or `None`, *optional*, defaults to `None`):
            Log with either `"wandb"` or `"tensorboard"`. Check
            [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
        log_image_freq (`int`, *optional*, defaults to `1`):
            Frequency for logging images.
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Path to resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Whether to use the 8bit Adam optimizer from `bitsandbytes`.
        train_learning_rate (`float`, *optional*, defaults to `5e-05`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Beta1 for Adam optimizer.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Beta2 for Adam optimizer.
        train_adam_weight_decay (`float`, *optional*, defaults to `0.01`):
            Weight decay for Adam optimizer.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-08`):
            Epsilon value for Adam optimizer.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `2`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        negative_prompts (`str` or `None`, *optional*, defaults to `None`):
            Comma-separated list of prompts to use as negative examples.
        truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
            If `True`, randomized truncation to different diffusion timesteps is used.
        truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
            Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model to the Hub.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams; stored on the instance, not passed to the parent class.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage; stored on the instance, not passed to the parent class.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        exp_name = 'colab_kernel_launcher',
        run_name = '',
        seed = 3407,
        log_with = None,
        log_image_freq = 1,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        negative_prompts = None,
        truncated_backprop_rand = True,
        truncated_backprop_timestep = 49,
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,

        **kwargs,
    ):
        # Everything the parent configuration receives by name, in signature order.
        parent_settings = dict(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            log_image_freq = log_image_freq,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            negative_prompts = negative_prompts,
            truncated_backprop_rand = truncated_backprop_rand,
            truncated_backprop_timestep = truncated_backprop_timestep,
            push_to_hub = push_to_hub,
        )
        super().__init__(**parent_settings, **kwargs)

        # The two Unsloth-only fields are kept on the instance and never reach the parent.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks

pass
211
+
212
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
    """
    Core AlignProp training loop around a Stable Diffusion pipeline.

    Each call to `step` generates images with gradients attached, scores them with the user
    supplied reward function, backpropagates `10.0 - mean(reward)` through the pipeline and
    updates the trainable layers with AdamW (optionally the 8-bit bitsandbytes variant).
    Device placement, mixed precision, gradient accumulation, logging and checkpointing are
    delegated to a Hugging Face `Accelerator`.
    """

    # Tags always added to generated model cards (see `create_model_card`).
    _tag_names = ["trl", "alignprop"]

    def __init__(
        self,
        config: AlignPropConfig,
        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
        prompt_function: Callable[[], tuple[str, Any]],
        sd_pipeline: DDPOStableDiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        """
        Build the accelerator, move the pipeline to the training device and create the optimizer.

        Args:
            config: Training configuration; read throughout as `self.config`.
            reward_function: Called with (images, prompts, prompt_metadata). `compute_rewards`
                unpacks its result into two values and keeps the first.
            prompt_function: Called with no arguments; must return a (prompt, metadata) pair.
            sd_pipeline: Pipeline providing `vae`, `text_encoder`, `unet`, `tokenizer`,
                `get_trainable_layers`, `rgb_with_grad`, `autocast` and the checkpoint helpers
                used below.
            image_samples_hook: Optional callback invoked from `step` with
                (prompt_image_pairs, global_step, first accelerator tracker).

        Raises:
            ValueError: If `config.resume_from` is a directory containing no `checkpoint_*` entries.
        """
        warnings.warn(
            "AlignPropTrainer is deprecated and will be removed in version 0.23.0.",
            DeprecationWarning,
        )
        if image_samples_hook is None:
            warnings.warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        if self.config.resume_from:
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = list(
                    filter(
                        lambda x: "checkpoint_" in x,
                        os.listdir(self.config.resume_from),
                    )
                )
                if len(checkpoints) == 0:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                # Highest numeric suffix wins; `resume_from` is rewritten to that checkpoint.
                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{checkpoint_numbers[-1]}",
                )

                # Continue checkpoint numbering after the one being resumed.
                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
            # the total number of optimizer steps to accumulate across.
            # NOTE(review): the value is passed through unmodified here; no multiplication happens in this class.
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
            **self.config.accelerator_kwargs,
        )

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

        if self.accelerator.is_main_process:
            # The config dict is passed flat for tensorboard and nested under
            # "alignprop_trainer_config" for every other tracker.
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(alignprop_trainer_config=config.to_dict())
                if not is_using_tensorboard
                else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline

        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        # Saving and loading of model weights is routed through the pipeline's own helpers.
        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        # `get_trainable_layers` may return either a module or a plain list of parameters.
        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        # Embedding of the negative prompt(s), computed once; a single empty prompt is used
        # when none are configured. `_generate_samples` repeats it per batch element.
        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            # With LoRA only the parameters that require gradients are kept as trainable layers.
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            # Training restarts at the epoch after the one encoded in the checkpoint name.
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0

    def compute_rewards(self, prompt_image_pairs):
        """
        Score a batch of generated images with the user supplied reward function.

        Args:
            prompt_image_pairs: Dict with the keys "images", "prompts" and "prompt_metadata",
                as produced by `_generate_samples`.

        Returns:
            The first element of the pair returned by `self.reward_fn`; the second element
            (reward metadata) is discarded.
        """
        # NOTE(review): the constructor annotates `reward_function` as returning a single
        # tensor, but the result is unpacked into two values here - confirm the expected contract.
        reward, reward_metadata = self.reward_fn(
            prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
        )
        return reward

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
              and the accelerator tracker.
            - Calls `accelerator.save_state()` on the main process when `epoch` is non-zero and a
              multiple of `config.save_freq`.

        Returns:
            global_step (int): The updated global step.

        Raises:
            ValueError: If the accelerator reports no gradient synchronisation after the
                accumulation loop.
        """
        info = defaultdict(list)

        self.sd_pipeline.unet.train()

        # One sample batch is generated, scored and backpropagated per accumulation step.
        for _ in range(self.config.train_gradient_accumulation_steps):
            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
                prompt_image_pairs = self._generate_samples(
                    batch_size=self.config.train_batch_size,
                )

                rewards = self.compute_rewards(prompt_image_pairs)

                prompt_image_pairs["rewards"] = rewards

                # Detached copy gathered across processes, used only for the logged statistics.
                rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()

                loss = self.calculate_loss(rewards)

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers,
                        self.config.train_max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()

            info["reward_mean"].append(rewards_vis.mean())
            info["reward_std"].append(rewards_vis.std())
            info["loss"].append(loss.item())

        # Checks if the accelerator has performed an optimization step behind the scenes
        if self.accelerator.sync_gradients:
            # log training-related stuff
            info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
            info = self.accelerator.reduce(info, reduction="mean")
            info.update({"epoch": epoch})
            self.accelerator.log(info, step=global_step)
            global_step += 1
            info = defaultdict(list)
        else:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )
        # Logs generated images
        # Only the batch from the last accumulation step is handed to the callback.
        # NOTE(review): `trackers[0]` assumes at least one tracker was initialised - confirm
        # the behaviour when `log_with` is None and a hook is supplied.
        if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
            self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step

    def calculate_loss(self, rewards):
        """
        Calculate the loss for a batch of an unpacked sample

        Args:
            rewards (torch.Tensor):
                Differentiable reward scalars for each generated image, shape: [batch_size]

        Returns:
            loss (torch.Tensor): `10.0 - rewards.mean()`, so a higher mean reward gives a lower loss.
        """
        # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
        loss = 10.0 - (rewards).mean()
        return loss

    def loss(
        self,
        advantages: torch.Tensor,
        clip_range: float,
        ratio: torch.Tensor,
    ):
        """
        Return the mean of the element-wise maximum of the unclipped and clipped losses.

        The unclipped term is `-advantages * ratio`; the clipped term uses `ratio` clamped to
        `[1 - clip_range, 1 + clip_range]`. This method is not called anywhere else in this class.
        """
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(
            ratio,
            1.0 - clip_range,
            1.0 + clip_range,
        )
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))

    def _setup_optimizer(self, trainable_layers_parameters):
        """
        Create the AdamW optimizer for the given parameters.

        Uses `bitsandbytes.optim.AdamW8bit` when `config.train_use_8bit_adam` is set (the
        import happens only in that case) and `torch.optim.AdamW` otherwise. Learning rate,
        betas, weight decay and epsilon come from the `train_*` config fields.
        """
        if self.config.train_use_8bit_adam:
            import bitsandbytes

            optimizer_cls = bitsandbytes.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        return optimizer_cls(
            trainable_layers_parameters,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )

    def _save_model_hook(self, models, weights, output_dir):
        """Accelerate save-state pre-hook: let the pipeline write its own checkpoint."""
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def _load_model_hook(self, models, input_dir):
        """Accelerate load-state pre-hook: let the pipeline restore its own checkpoint."""
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model

    def _generate_samples(self, batch_size, with_grad=True, prompts=None):
        """
        Generate samples from the model

        Args:
            batch_size (int): Batch size to use for sampling
            with_grad (bool): Whether the generated RGBs should have gradients attached to it.
                When True the pipeline's `rgb_with_grad` is used together with the
                truncated-backprop settings from the config; otherwise the pipeline is called directly.
            prompts: Optional prompts to use instead of calling `self.prompt_fn`. When given,
                the prompt metadata is a list of `batch_size` empty dicts.

        Returns:
            prompt_image_pairs (dict[Any]): Dict with the keys "images", "prompts" and "prompt_metadata".
        """
        prompt_image_pairs = {}

        # Repeat the cached negative-prompt embedding once per sample.
        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        if prompts is None:
            # `prompt_fn` yields (prompt, metadata) pairs; unzip them into two tuples.
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
        else:
            prompt_metadata = [{} for _ in range(batch_size)]

        prompt_ids = self.sd_pipeline.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)

        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

        if with_grad:
            sd_output = self.sd_pipeline.rgb_with_grad(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                truncated_backprop_rand=self.config.truncated_backprop_rand,
                truncated_backprop_timestep=self.config.truncated_backprop_timestep,
                truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
                output_type="pt",
            )
        else:
            sd_output = self.sd_pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                output_type="pt",
            )

        images = sd_output.images

        prompt_image_pairs["images"] = images
        prompt_image_pairs["prompts"] = prompts
        prompt_image_pairs["prompt_metadata"] = prompt_metadata

        return prompt_image_pairs

    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs

        Args:
            epochs: Exclusive upper bound of the epoch range; defaults to `config.num_epochs`.
                The loop starts at `self.first_epoch`, while `global_step` always starts at 0,
                including when resuming from a checkpoint.
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

    def _save_pretrained(self, save_directory):
        """Save the pipeline to `save_directory`, then create a model card."""
        self.sd_pipeline.save_pretrained(save_directory)
        self.create_model_card()

    # Ensure the model card is saved along with the checkpoint
    # NOTE(review): `self.args` is never assigned in this class, and whether the parent mixin
    # provides `_save_checkpoint` cannot be seen from this file - confirm this method is reachable.
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` inside `self.args.output_dir`. The tags in
        `_tag_names` are always added, and "unsloth" is added when the model config carries
        an `unsloth_version` attribute.

        NOTE(review): this method reads `self.is_world_process_zero`, `self.model`,
        `self.hub_model_id` and `self.args`, none of which are set in this class - confirm they
        are provided elsewhere before relying on `_save_pretrained` / `create_model_card`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        # A local directory path is not reported as the base model.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        # NOTE(review): the indentation and column alignment inside this literal were
        # reconstructed; only the visible text of each line is known - confirm against upstream.
        citation = textwrap.dedent("""\
        @article{prabhudesai2024aligning,
            title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
            author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
            year = 2024,
            eprint = {arXiv:2310.03739}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="AlignProp",
            trainer_citation=citation,
            paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
            paper_id="2310.03739",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
626
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
    """

    Public Unsloth entry point for AlignProp training.

    AlignProp fine-tunes a diffusion model by backpropagating a differentiable reward through
    the sampling process. Note, this trainer is heavily inspired by the work here:
    https://github.com/mihirp1998/AlignProp/ As of now only Stable Diffusion based pipelines
    are supported.

    Before delegating to `_UnslothAlignPropTrainer.__init__`, the constructor substitutes a
    default `UnslothAlignPropConfig` when `config` is `None` and registers the trainer with
    `unsloth_zoo.logging_utils.PatchRLStatistics`.

    Attributes:
        config (`AlignPropConfig`):
            Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
            Passing `None` selects `UnslothAlignPropConfig()` with its defaults.
        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
            Reward function to be used
        prompt_function (`Callable[[], tuple[str, Any]]`):
            Function to generate prompts to guide model
        sd_pipeline (`DDPOStableDiffusionPipeline`):
            Stable Diffusion pipeline to be used for training.
        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
            Hook to be called to log images

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        # The generated code tested an unbound local named `args` here, which raised
        # UnboundLocalError on every construction. This trainer's configuration parameter
        # is `config`, so the default must be applied to that name.
        if config is None: config = UnslothAlignPropConfig()

        # No extra metrics are registered for this trainer.
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('alignprop_trainer', other_metrics)

        # Extra keyword arguments are forwarded unchanged to the parent constructor.
        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,**kwargs)

pass
669
+
670
+
671
# Silence log records mentioning `use_cache=True` on the logger imported from TRL,
# provided that object supports filters at all.
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that drops every record whose message contains `text`."""

        def __init__(self, text):
            self.text = text

        def filter(self, x):
            # A falsy result tells the logging machinery to discard the record.
            return self.text not in x.getMessage()

    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
678
+
unsloth_compiled_cache/UnslothBCOTrainer.py ADDED
@@ -0,0 +1,1857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, autocast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_joblib_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, joblib, log_table_to_comet_experiment, logger, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, logger, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options handed to `torch.compile(..., options = ...)` by the compiled helpers in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-probability of each selected index, computed chunk by chunk.

    `logits` is flattened to rows of size `logits.shape[-1]` and `index` to a flat vector; both are split
    into at most 4 chunks along the row dimension. For every chunk the rows are upcast to float32 and the
    result is `gathered_logit - logsumexp(row)`, i.e. the same value `selective_log_softmax` would produce.

    Args:
        logits: Tensor of scores. The result is reshaped to `(logits.shape[0], logits.shape[1])`, so this
            presumably expects a 3-D `(batch, sequence, vocab)` tensor; confirm against callers.
        index: Integer tensor with one entry per row of the flattened logits.

    Returns:
        float32 tensor of shape `(logits.shape[0], logits.shape[1])` holding the per-token log-probabilities.
    """
    vocab_size = logits.shape[-1]
    rows = logits.reshape(-1, vocab_size)
    targets = index.reshape(-1)

    per_chunk = []
    row_blocks = torch.chunk(rows, chunks = 4, dim = 0)
    target_blocks = torch.chunk(targets, chunks = 4, dim = 0)
    for row_block, target_block in zip(row_blocks, target_blocks):
        # Upcast before the reductions so logsumexp runs in full precision.
        row_block = row_block.to(torch.float32)
        picked = torch.gather(row_block, dim = -1, index = target_block.unsqueeze(-1)).squeeze(-1)
        per_chunk.append(picked - torch.logsumexp(row_block, dim = -1))

    batch, seq_len = logits.shape[0], logits.shape[1]
    return torch.concat(per_chunk).reshape((batch, seq_len))
51
@dataclass
class UnslothBCOConfig(BCOConfig):
    """
    Configuration for the [`BCOTrainer`], extended with a few Unsloth-specific fields.

    Only the BCO-specific parameters are described here; every other argument is forwarded unchanged to
    `BCOConfig` (see the [`~transformers.TrainingArguments`] documentation for their meaning). The defaults
    that apply are the ones in this class's `__init__` signature.

    With [`~transformers.HfArgumentParser`] this class can be turned into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be given on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. Required when the default data
            collator is used.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. Required when the default data collator is used.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion. Required when the default data collator is used and the model is
            an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. Required when the default data collator is used.
        padding_value (`int` or `None`, *optional*, defaults to `None`):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or
            `"keep_start"`. Required when the default data collator is used.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from both the model and the reference model to W&B or
            Comet during evaluation.
        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
            When the model is created through the `model_init` callable instead of the `model` argument, states
            whether the model returned by the callable is an encoder-decoder model.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
            Whether to precompute reference model log probabilities for the training and evaluation datasets.
            Useful when training without the reference model to reduce the total GPU memory needed.
        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments passed to `AutoModelForCausalLM.from_pretrained` when the model is given as a
            string.
        ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments passed to `AutoModelForCausalLM.from_pretrained` when the reference model is given
            as a string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, `__init__` replaces it with
            `min(cpu_count()*2, 2)`.
        prompt_sample_size (`int`, *optional*, defaults to `1024`):
            Number of prompts that are fed to the density ratio classifier.
        min_density_ratio (`float`, *optional*, defaults to `0.5`):
            Minimum value of the density ratio. The estimated density ratio is clamped to this value.
        max_density_ratio (`float`, *optional*, defaults to `10.0`):
            Maximum value of the density ratio. The estimated density ratio is clamped to this value.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to `BCOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. Stored on the instance; not forwarded to `BCOConfig`.
        max_seq_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum sequence length to truncate to. Stored on the instance; not forwarded to `BCOConfig`.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.
    """
    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={'help': 'Maximum sequence length to truncate to.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        hub_revision=None,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=True,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        liger_kernel_config=None,
        eval_use_gather_object=False,
        average_tokens_across_devices=True,
        max_length=1024,
        max_prompt_length=512,
        max_completion_length=None,
        beta=0.1,
        label_pad_token_id=-100,
        padding_value=None,
        truncation_mode='keep_end',
        disable_dropout=True,
        generate_during_eval=False,
        is_encoder_decoder=None,
        precompute_ref_log_probs=False,
        model_init_kwargs=None,
        ref_model_init_kwargs=None,
        dataset_num_proc=None,
        prompt_sample_size=1024,
        min_density_ratio=0.5,
        max_density_ratio=10.0,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        max_seq_length=None,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before anything else is configured.
        if learning_rate < 1e-7:
            raise FloatingPointError(
                f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!'
            )
        if learning_rate > 1:
            raise OverflowError(
                f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!'
            )

        # No output directory and untouched checkpoint settings: fall back to a default
        # directory name and switch checkpoint saving off.
        untouched_save_settings = save_strategy == 'steps' and save_steps == 500
        if output_dir is None and untouched_save_settings:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            # NOTE: evaluates to 2 whenever cpu_count() >= 1.
            dataset_num_proc = min(cpu_count()*2, 2)

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            hub_revision=hub_revision,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            liger_kernel_config=liger_kernel_config,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_completion_length=max_completion_length,
            beta=beta,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            disable_dropout=disable_dropout,
            generate_during_eval=generate_during_eval,
            is_encoder_decoder=is_encoder_decoder,
            precompute_ref_log_probs=precompute_ref_log_probs,
            model_init_kwargs=model_init_kwargs,
            ref_model_init_kwargs=ref_model_init_kwargs,
            dataset_num_proc=dataset_num_proc,
            prompt_sample_size=prompt_sample_size,
            min_density_ratio=min_density_ratio,
            max_density_ratio=max_density_ratio,
            **kwargs,
        )

        # Unsloth-only settings are kept on the instance and never reach BCOConfig.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
434
+
435
+ class _UnslothBCOTrainer(Trainer):
436
+ r""""""
437
+
438
+ _tag_names = ["trl", "bco"]
439
+
440
+ def __init__(
441
+ self,
442
+ model: Union[PreTrainedModel, nn.Module, str] = None,
443
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
444
+ args: BCOConfig = None,
445
+ train_dataset: Optional[Dataset] = None,
446
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
447
+ processing_class: Optional[
448
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
449
+ ] = None,
450
+ data_collator: Optional[DataCollator] = None,
451
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
452
+ callbacks: Optional[list[TrainerCallback]] = None,
453
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
454
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
455
+ peft_config: Optional[dict] = None,
456
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
457
+ model_adapter_name: Optional[str] = None,
458
+ ref_adapter_name: Optional[str] = None,
459
+ embedding_func: Optional[Callable] = None,
460
+ embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
461
+ ):
462
+ if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
463
+ raise ImportError(
464
+ "BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`."
465
+ )
466
+
467
+ if type(args) is TrainingArguments:
468
+ raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
469
+
470
+ if not isinstance(model, str) and model is not None and ref_model is model:
471
+ raise ValueError(
472
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
473
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
474
+ )
475
+
476
+ if args.model_init_kwargs is None:
477
+ model_init_kwargs = {}
478
+ elif not isinstance(model, str):
479
+ raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
480
+ else:
481
+ model_init_kwargs = args.model_init_kwargs
482
+ torch_dtype = model_init_kwargs.get("torch_dtype")
483
+ if torch_dtype is not None:
484
+ # Convert to `torch.dtype` if an str is passed
485
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
486
+ torch_dtype = getattr(torch, torch_dtype)
487
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
488
+ raise ValueError(
489
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
490
+ )
491
+ model_init_kwargs["torch_dtype"] = torch_dtype
492
+
493
+ if args.ref_model_init_kwargs is None:
494
+ ref_model_init_kwargs = {}
495
+ elif not isinstance(ref_model, str):
496
+ raise ValueError(
497
+ "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
498
+ )
499
+ else:
500
+ ref_model_init_kwargs = args.ref_model_init_kwargs
501
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
502
+ if torch_dtype is not None:
503
+ # Convert to `torch.dtype` if an str is passed
504
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
505
+ torch_dtype = getattr(torch, torch_dtype)
506
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
507
+ raise ValueError(
508
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
509
+ )
510
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
511
+
512
+ if isinstance(model, str):
513
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
514
+
515
+ if isinstance(ref_model, str):
516
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
517
+
518
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
519
+ # has been called in order to properly call autocast if needed.
520
+ self._peft_has_been_casted_to_bf16 = False
521
+
522
+ if not is_peft_available() and peft_config is not None:
523
+ raise ValueError(
524
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
525
+ )
526
+ elif is_peft_available() and peft_config is not None:
527
+ # if model is a peft model and we have a peft_config, we merge and unload it first
528
+ if isinstance(model, PeftModel):
529
+ model = model.merge_and_unload()
530
+
531
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
532
+ _support_gc_kwargs = hasattr(
533
+ args, "gradient_checkpointing_kwargs"
534
+ ) and "gradient_checkpointing_kwargs" in list(
535
+ inspect.signature(prepare_model_for_kbit_training).parameters
536
+ )
537
+
538
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
539
+
540
+ if _support_gc_kwargs:
541
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
542
+
543
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
544
+ elif args.gradient_checkpointing:
545
+ # For backward compatibility with older versions of transformers
546
+ if hasattr(model, "enable_input_require_grads"):
547
+ model.enable_input_require_grads()
548
+ else:
549
+
550
+ def make_inputs_require_grad(module, input, output):
551
+ output.requires_grad_(True)
552
+
553
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
554
+
555
+ # get peft model with the given config
556
+ model = model
557
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
558
+ peft_module_casting_to_bf16(model)
559
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
560
+ self._peft_has_been_casted_to_bf16 = True
561
+
562
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
563
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
564
+ # fail or completely fail.
565
+ elif args.gradient_checkpointing:
566
+ # For backward compatibility with older versions of transformers
567
+ if hasattr(model, "enable_input_require_grads"):
568
+ model.enable_input_require_grads()
569
+ else:
570
+
571
+ def make_inputs_require_grad(module, input, output):
572
+ output.requires_grad_(True)
573
+
574
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
575
+
576
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
577
+ raise ValueError(
578
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
579
+ " Please install `wandb` or `comet-ml` to resolve."
580
+ )
581
+
582
+ if model is not None:
583
+ self.is_encoder_decoder = model.config.is_encoder_decoder
584
+ elif args.is_encoder_decoder is None:
585
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
586
+ else:
587
+ self.is_encoder_decoder = args.is_encoder_decoder
588
+
589
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
590
+ self.model_adapter_name = model_adapter_name
591
+ self.ref_adapter_name = ref_adapter_name
592
+
593
+ if ref_model:
594
+ self.ref_model = ref_model
595
+ elif self.is_peft_model or args.precompute_ref_log_probs:
596
+ # The `model` with adapters turned off will be used as the reference model
597
+ self.ref_model = None
598
+ else:
599
+ self.ref_model = create_reference_model(model)
600
+
601
+ if processing_class is None:
602
+ raise ValueError(
603
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
604
+ )
605
+ if args.max_length is None:
606
+ warnings.warn(
607
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
608
+ "It will be set to `512` by default, but you should do it yourself in the future.",
609
+ UserWarning,
610
+ )
611
+ max_length = 512
612
+ if args.max_length is not None:
613
+ max_length = args.max_length
614
+
615
+ if args.max_prompt_length is None:
616
+ warnings.warn(
617
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
618
+ "It will be set to `128` by default, but you should do it yourself in the future.",
619
+ UserWarning,
620
+ )
621
+ max_prompt_length = 128
622
+ if args.max_prompt_length is not None:
623
+ max_prompt_length = args.max_prompt_length
624
+
625
+ max_completion_length = None
626
+ if args.max_completion_length is None and self.is_encoder_decoder:
627
+ warnings.warn(
628
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
629
+ " it will be set to `128` by default, but you should do it yourself in the future.",
630
+ UserWarning,
631
+ )
632
+ max_completion_length = 128
633
+ if args.max_completion_length is not None and self.is_encoder_decoder:
634
+ max_completion_length = args.max_completion_length
635
+
636
+ if data_collator is None:
637
+ data_collator = DPODataCollatorWithPadding(
638
+ pad_token_id=processing_class.pad_token_id,
639
+ label_pad_token_id=args.label_pad_token_id,
640
+ is_encoder_decoder=self.is_encoder_decoder,
641
+ )
642
+
643
+ if args.remove_unused_columns:
644
+ args.remove_unused_columns = False
645
+ # warn users
646
+ warnings.warn(
647
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
648
+ " we have set it for you, but you should do it yourself in the future.",
649
+ UserWarning,
650
+ )
651
+
652
+ self.use_dpo_data_collator = True
653
+ else:
654
+ self.use_dpo_data_collator = False
655
+
656
+ # Disable dropout in the model and reference model
657
+ if args.disable_dropout:
658
+ disable_dropout_in_model(model)
659
+ if self.ref_model is not None:
660
+ disable_dropout_in_model(self.ref_model)
661
+
662
+ self.max_length = max_length
663
+ self.generate_during_eval = args.generate_during_eval
664
+ self.label_pad_token_id = args.label_pad_token_id
665
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
666
+ self.max_prompt_length = max_prompt_length
667
+ self.truncation_mode = args.truncation_mode
668
+ self.max_completion_length = max_completion_length
669
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
670
+
671
+ # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
672
+ # keep track of first called to avoid computation of future calls
673
+ self._precomputed_train_ref_log_probs = False
674
+ self._precomputed_eval_ref_log_probs = False
675
+
676
+ # metric
677
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
678
+
679
+ # BCO parameter
680
+ self.beta = args.beta
681
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
682
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
683
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
684
+ warnings.warn(
685
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
686
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
687
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
688
+ "loss.",
689
+ UserWarning,
690
+ )
691
+
692
+ # Underlying Distribution Matching argument
693
+ self.embedding_func = embedding_func
694
+ self.embedding_tokenizer = embedding_tokenizer
695
+
696
+ # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
697
+ # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
698
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
699
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
700
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
701
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
702
+ # issued.
703
+ model.warnings_issued["estimate_tokens"] = True
704
+
705
+ with PartialState().main_process_first():
706
+ # Apply the chat template if needed
707
+ train_dataset = train_dataset.map(
708
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
709
+ )
710
+ if eval_dataset is not None:
711
+ eval_dataset = eval_dataset.map(
712
+ maybe_apply_chat_template,
713
+ fn_kwargs={"tokenizer": processing_class},
714
+ num_proc=args.dataset_num_proc,
715
+ )
716
+
717
+ # Tokenize and prepare the training datasets
718
+ train_dataset = train_dataset.map(
719
+ _tokenize,
720
+ batched=True,
721
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
722
+ num_proc=args.dataset_num_proc,
723
+ desc="Tokenizing train dataset",
724
+ )
725
+
726
+ # Prepare the datasets
727
+ fn_kwargs = {
728
+ "prefix": "",
729
+ "is_encoder_decoder": self.is_encoder_decoder,
730
+ "tokenizer": processing_class,
731
+ "max_length": self.max_length,
732
+ "truncation_mode": self.truncation_mode,
733
+ "label_pad_token_id": self.label_pad_token_id,
734
+ "max_prompt_length": self.max_prompt_length,
735
+ "max_completion_length": self.max_completion_length,
736
+ }
737
+ train_dataset = train_dataset.map(
738
+ _process_tokens,
739
+ fn_kwargs=fn_kwargs,
740
+ num_proc=args.dataset_num_proc,
741
+ desc="Processing tokenized train dataset",
742
+ )
743
+
744
+ if eval_dataset is not None:
745
+ # Tokenize
746
+ eval_dataset = eval_dataset.map(
747
+ _tokenize,
748
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
749
+ batched=True,
750
+ num_proc=args.dataset_num_proc,
751
+ desc="Tokenizing eval dataset",
752
+ )
753
+
754
+ # Process
755
+ fn_kwargs = {
756
+ "prefix": "",
757
+ "is_encoder_decoder": self.is_encoder_decoder,
758
+ "tokenizer": processing_class,
759
+ "max_length": self.max_length,
760
+ "truncation_mode": self.truncation_mode,
761
+ "label_pad_token_id": self.label_pad_token_id,
762
+ "max_prompt_length": self.max_prompt_length,
763
+ "max_completion_length": self.max_completion_length,
764
+ }
765
+ eval_dataset = eval_dataset.map(
766
+ _process_tokens,
767
+ fn_kwargs=fn_kwargs,
768
+ num_proc=args.dataset_num_proc,
769
+ desc="Processing tokenized eval dataset",
770
+ )
771
+
772
+ desirable = train_dataset.filter(
773
+ lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
774
+ )
775
+ undesirable = train_dataset.filter(
776
+ lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
777
+ )
778
+
779
+ super().__init__(
780
+ model=model,
781
+ args=args,
782
+ data_collator=data_collator,
783
+ train_dataset=train_dataset,
784
+ eval_dataset=eval_dataset,
785
+ processing_class=processing_class,
786
+ model_init=model_init,
787
+ compute_metrics=compute_metrics,
788
+ callbacks=callbacks,
789
+ optimizers=optimizers,
790
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
791
+ )
792
+
793
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
794
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
795
+ # self.model_accepts_loss_kwargs to False to enable scaling.
796
+ self.model_accepts_loss_kwargs = False
797
+
798
+ # Add tags for models that have been loaded with the correct transformers version
799
+ if hasattr(self.model, "add_model_tags"):
800
+ self.model.add_model_tags(self._tag_names)
801
+
802
+ if not hasattr(self, "accelerator"):
803
+ raise AttributeError(
804
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
805
+ )
806
+
807
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
808
+ if self.is_deepspeed_enabled:
809
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
810
+ raise ValueError(
811
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
812
+ )
813
+
814
+ if self.ref_model is None:
815
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
816
+ raise ValueError(
817
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
818
+ )
819
+ else:
820
+ if self.is_deepspeed_enabled:
821
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
822
+ else:
823
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
824
+
825
+ self.running = RunningMoments(accelerator=self.accelerator)
826
+
827
+ if self.embedding_func is None or args.resume_from_checkpoint:
828
+ return
829
+
830
+ chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
831
+ rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
832
+
833
+ embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
834
+ labels = torch.cat(
835
+ (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
836
+ )
837
+
838
+ self.clf = LogisticRegression(class_weight="balanced").fit(
839
+ embeddings.cpu().float().numpy(), labels.cpu().numpy()
840
+ )
841
+ chosen_mean = self.clf.score(
842
+ chosen_embeddings.cpu().float().numpy(), torch.ones_like(chosen_embeddings[:, 0]).cpu().numpy()
843
+ )
844
+ rejected_mean = self.clf.score(
845
+ rejected_embeddings.cpu().float().numpy(), torch.zeros_like(rejected_embeddings[:, 0]).cpu().numpy()
846
+ )
847
+ logger.info(f"UDM classifier training scores: chosen: {chosen_mean}, rejected: {rejected_mean}")
848
+
849
@property
def match_underlying_distribution(self):
    """Report whether underlying-distribution matching (UDM) can be used.

    Returns:
        `True` only when both `self.embedding_func` and `self.embedding_tokenizer` are set (not `None`).
    """
    if self.embedding_func is None:
        return False
    return self.embedding_tokenizer is not None
852
+
853
+ def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
854
+ """
855
+ Calculates the probability if the given prompt embedding is from desirable dataset. This function calculates
856
+ the probability in the process and ensemble across processes.
857
+ """
858
+ dtype = prompt_embeddings.dtype
859
+ device = prompt_embeddings.device
860
+ rank = self.accelerator.process_index
861
+
862
+ padded_prompt_embeddings = self.accelerator.pad_across_processes(
863
+ prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
864
+ )
865
+ sample_size = padded_prompt_embeddings.shape[0]
866
+ nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
867
+ prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
868
+
869
+ # cannot predict for all empty values
870
+ if prompt_embeddings.shape[0] == 0:
871
+ return torch.tensor([], device=device, dtype=dtype)
872
+
873
+ prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
874
+ prob = torch.as_tensor(prob, dtype=dtype, device=device)
875
+ prob = self.accelerator.reduce(prob, reduction="mean")
876
+
877
+ prob = prob[sample_size * rank : sample_size * (rank + 1)]
878
+ prob = prob[nonzero]
879
+
880
+ return prob
881
+
882
+ def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
883
+ """
884
+ Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id and applies self.embedding_func
885
+ """
886
+ input_ids = torch.where(
887
+ input_ids == self.processing_class.pad_token_id,
888
+ self.embedding_tokenizer.pad_token_id,
889
+ input_ids,
890
+ )
891
+
892
+ with torch.no_grad():
893
+ embeddings = self.embedding_func(
894
+ input_ids=input_ids,
895
+ attention_mask=attention_mask,
896
+ )
897
+
898
+ return embeddings
899
+
900
+ def _get_prompt_embeddings(
901
+ self, batch: dict[str, Union[list, torch.LongTensor]]
902
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
903
+ """Extract embeddings from frozen embedding model"""
904
+
905
+ if not self.match_underlying_distribution:
906
+ return None, None
907
+
908
+ embeddings = self._vectorize_prompt(
909
+ input_ids=batch["embedding_input_ids"],
910
+ attention_mask=batch["embedding_attention_mask"],
911
+ )
912
+
913
+ labels = torch.tensor(batch["label"], dtype=torch.bool, device=embeddings.device)
914
+ chosen_idx = torch.where(labels)[0]
915
+ rejected_idx = torch.where(~labels)[0]
916
+
917
+ chosen_embeddings = embeddings[chosen_idx, ...]
918
+ rejected_embeddings = embeddings[rejected_idx, ...]
919
+
920
+ return (chosen_embeddings, rejected_embeddings)
921
+
922
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
    """
    Embed a random sample of prompts from `dataset`.

    The result is used in `__init__` to fit the density-ratio classifier.

    Args:
        dataset: Dataset to sample from; must support `len()` and `.select(indices)`.
        sample_size: Upper bound on the number of rows drawn.

    Returns:
        A CPU tensor with the embeddings of the sampled prompts stacked along dim 0.
    """
    population = len(dataset)
    n_samples = min(population, sample_size)
    # NOTE: `np.random.choice` samples with replacement by default, so an index can appear more than once.
    picked = np.random.choice(population, size=(n_samples,))
    subset = dataset.select(picked)

    # prepare dataloader
    loader = self.accelerator.prepare(
        DataLoader(
            subset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            shuffle=False,
        )
    )

    # Start from an empty CPU tensor so that a tensor is returned even if the loader yields no batch.
    chunks = [torch.empty(0)]
    with torch.no_grad():
        for padded_batch in tqdm(iterable=loader, desc="Building sample prompt embeddings"):
            batch_embeddings = self._vectorize_prompt(
                input_ids=padded_batch["embedding_input_ids"],
                attention_mask=padded_batch["embedding_attention_mask"],
            )
            batch_embeddings = self.accelerator.gather_for_metrics(batch_embeddings)
            chunks.append(batch_embeddings.cpu())

        return torch.cat(chunks)
953
+
954
def _save_optimizer_and_scheduler(self, output_dir):
    """Save optimizer/scheduler state plus the BCO-specific state into `output_dir`.

    On the main process the running-moments object is written to `RUNNING_NAME`, and, when
    underlying-distribution matching is enabled, the fitted classifier is dumped to `CLF_NAME`.

    Args:
        output_dir: Target directory; `self.args.output_dir` is used when this is `None`.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    super()._save_optimizer_and_scheduler(output_dir)

    if not self.accelerator.is_main_process:
        return

    # When saving optimizer and scheduler to checkpoint, save also the running delta object.
    running_path = os.path.join(output_dir, RUNNING_NAME)
    self.running.save_to_json(running_path)

    if self.match_underlying_distribution:
        clf_path = os.path.join(output_dir, CLF_NAME)
        joblib.dump(self.clf, clf_path, compress=True)
964
+
965
def _load_optimizer_and_scheduler(self, checkpoint):
    """Restore optimizer/scheduler state plus the BCO-specific state from `checkpoint`.

    The running-moments object and, when underlying-distribution matching is enabled, the classifier are
    reloaded only if their files exist in the checkpoint directory; missing files are skipped without error.

    Args:
        checkpoint: Checkpoint directory. When `None`, a warning is logged and nothing is loaded.
    """
    if checkpoint is None:
        logger.warning_once(f"Missing Checkpoint {checkpoint}")
        return

    super()._load_optimizer_and_scheduler(checkpoint)

    # when loading optimizer and scheduler from checkpoint, also load the running delta object.
    running_path = os.path.join(checkpoint, RUNNING_NAME)
    if os.path.isfile(running_path):
        self.running = RunningMoments.load_from_json(self.accelerator, running_path)

    if not self.match_underlying_distribution:
        return

    clf_path = os.path.join(checkpoint, CLF_NAME)
    if os.path.isfile(clf_path):
        self.clf = joblib.load(clf_path)
981
+
982
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    Inside the context the policy model acts as the reference model:

    - if `self.ref_adapter_name` is set, that adapter is activated on `self.model` for the duration of the
      context, and `self.model_adapter_name` (or `"default"`) is re-activated on exit;
    - otherwise, if the model is a PEFT model, its adapters are disabled for the duration of the context;
    - otherwise the context does nothing.

    The policy adapter is restored in a `finally` clause, so an exception raised inside the `with` body does
    not leave the reference adapter active on the policy model.
    """
    if self.is_peft_model and not self.ref_adapter_name:
        adapter_ctx = self.accelerator.unwrap_model(self.model).disable_adapter()
    else:
        adapter_ctx = nullcontext()

    with adapter_ctx:
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Always switch back to the policy adapter, even when the body raised.
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")
995
+
996
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Overrides the parent `get_train_dataloader` to precompute `ref_log_probs`: on the first call with
    `precompute_ref_log_probs` enabled, the reference log-probabilities of the whole training set are computed
    and stored in a new `reference_logps` column of `self.train_dataset`.
    """
    should_precompute = self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs
    if should_precompute:
        # prepare dataloader
        loader = self.accelerator.prepare(
            DataLoader(
                self.train_dataset,
                batch_size=self.args.per_device_train_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
                # Unshuffled, so that the computed values line up with the dataset rows.
                shuffle=False,
            )
        )

        per_batch_logps = [
            self.accelerator.gather_for_metrics(self.compute_reference_log_probs(padded_batch)).cpu()
            for padded_batch in tqdm(iterable=loader, desc="Train dataset reference log probs")
        ]

        self.train_dataset = self.train_dataset.add_column(
            name="reference_logps", column=torch.cat(per_batch_logps).float().numpy()
        )
        self._precomputed_train_ref_log_probs = True

    return super().get_train_dataloader()
1029
+
1030
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Overrides the parent `get_eval_dataloader` to precompute `ref_log_probs`: on the first call with
    `precompute_ref_log_probs` enabled, the reference log-probabilities of the evaluation set are computed and
    stored in a new `reference_logps` column.

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.

    Raises:
        ValueError: If neither `eval_dataset` nor `self.eval_dataset` is available.
    """
    if eval_dataset is None:
        if self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = self.eval_dataset

    should_precompute = self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs
    if should_precompute:
        # prepare dataloader
        loader = self.accelerator.prepare(
            DataLoader(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
                # Unshuffled, so that the computed values line up with the dataset rows.
                shuffle=False,
            )
        )

        per_batch_logps = [
            self.accelerator.gather_for_metrics(self.compute_reference_log_probs(padded_batch)).cpu()
            for padded_batch in tqdm(iterable=loader, desc="Eval dataset reference log probs")
        ]

        eval_dataset = eval_dataset.add_column(
            name="reference_logps", column=torch.cat(per_batch_logps).float().numpy()
        )

        # Save the calculated reference log probs to the eval_dataset for subsequent runs
        if self.eval_dataset is not None:
            self.eval_dataset = eval_dataset
        self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)
1075
+
1076
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
    """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset.

    When `self.ref_model` is `None`, the policy model is run inside `self.null_ref_context()` and serves as the
    reference. The forward pass runs under `torch.no_grad()`.

    Args:
        padded_batch: Collated batch. Decoder-only models read `completion_input_ids` and
            `completion_attention_mask`; encoder-decoder models read `prompt_input_ids`,
            `prompt_attention_mask` and, if present, `completion_decoder_input_ids`. `completion_labels` is
            read in both cases.

    Returns:
        The value returned by `self.get_batch_logps` with `average_log_prob=False`, that is the summed
        completion log-probabilities (despite the `dict` return annotation).
    """
    use_policy_as_reference = self.ref_model is None
    scoring_model = self.model if use_policy_as_reference else self.ref_model
    # `nullcontext()` is a no-op, so a dedicated reference model runs without any adapter manipulation.
    adapter_ctx = self.null_ref_context() if use_policy_as_reference else nullcontext()

    with torch.no_grad(), adapter_ctx:
        if self.is_encoder_decoder:
            completion_logits = scoring_model(
                padded_batch["prompt_input_ids"],
                attention_mask=padded_batch["prompt_attention_mask"],
                decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                labels=padded_batch["completion_labels"],
            ).logits
        else:
            completion_logits = scoring_model(
                padded_batch["completion_input_ids"],
                attention_mask=padded_batch["completion_attention_mask"],
            ).logits

    return self.get_batch_logps(
        completion_logits,
        padded_batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )
1118
+
1119
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels:
            Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
            ignored. Shape: (batch_size, sequence_length)
        average_log_prob:
            If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
            log probabilities of the (non-masked) tokens.
        label_pad_token_id: Label value marking positions excluded from the result.
        is_encoder_decoder:
            If False, labels and logits are shifted by one position against each other before scoring. If True,
            they are used as they are.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
        given logits.

    Raises:
        ValueError: If the batch and sequence dimensions of `logits` and `labels` differ.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    # Always work on a copy: the labels are modified in place below.
    if is_encoder_decoder:
        labels = labels.clone()
    else:
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    keep = labels != label_pad_token_id

    # dummy token; the log-probs at these positions are zeroed out by `keep` below
    labels[~keep] = 0

    token_logps = selective_log_softmax(logits, labels)
    summed_logps = (token_logps * keep).sum(-1)

    if not average_log_prob:
        return summed_logps
    return summed_logps / keep.sum(-1)
1163
+
1164
def forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run `model` on a batch and split the completion log-probs and logits by label.

    Examples are assigned to the chosen or rejected group by the truthiness of their label, which matches how
    the datasets are split elsewhere in this trainer. This way labels that are truthy but are not the `True`
    singleton (e.g. `1` or tensor elements) are no longer silently dropped from both groups.

    Args:
        model: Model to evaluate (the policy or the reference model).
        batch: Collated batch containing `completion_input_ids`, `completion_attention_mask`,
            `completion_labels` and `label`; `completion_decoder_input_ids` is used if present for
            encoder-decoder models.

    Returns:
        `(chosen_logps, rejected_logps, chosen_logits, rejected_logits)`, followed by `outputs.aux_loss` as a
        fifth element when `self.aux_loss_enabled` is true.

    Raises:
        ValueError: If the number of scored sequences differs from the number of labels.
    """
    model_kwargs = {}
    if self.is_encoder_decoder:
        model_kwargs["labels"] = batch["completion_labels"]
        model_kwargs["decoder_input_ids"] = batch.get("completion_decoder_input_ids")
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    # NOTE(review): `compute_reference_log_probs` feeds `prompt_input_ids` to encoder-decoder models, while
    # this method always feeds `completion_input_ids` — confirm which one is intended for that case.
    outputs = model(
        batch["completion_input_ids"],
        attention_mask=batch["completion_attention_mask"],
        **model_kwargs,
    )
    completion_logits = outputs.logits

    completion_logps = self.get_batch_logps(
        completion_logits,
        batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    labels = batch["label"]
    if completion_logps.shape[0] != len(labels):
        raise ValueError(
            "There is a mismatch between the number of examples in this batch and the number of "
            "examples for which an output sequence was predicted."
        )

    # Partition by truthiness so that every example lands in exactly one of the two groups.
    chosen_idx, rejected_idx = [], []
    for i in range(completion_logps.shape[0]):
        (chosen_idx if bool(labels[i]) else rejected_idx).append(i)

    chosen_logps = completion_logps[chosen_idx, ...]
    rejected_logps = completion_logps[rejected_idx, ...]

    chosen_logits = completion_logits[chosen_idx, ...]
    rejected_logits = completion_logits[rejected_idx, ...]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1212
+
1213
+ def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
1214
+ prob_desirable = self._get_chosen_prob(rejected_embeddings)
1215
+ min_ratio = self.args.min_density_ratio
1216
+ max_ratio = self.args.max_density_ratio
1217
+
1218
+ weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
1219
+
1220
+ return weight
1221
+
1222
def bco_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    chosen_embeddings: Optional[torch.FloatTensor],
    rejected_embeddings: Optional[torch.FloatTensor],
    do_train: bool = True,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the BCO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps:
            Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
        policy_rejected_logps:
            Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
        reference_chosen_logps:
            Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
        reference_rejected_logps:
            Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in
            batch_size,)
        chosen_embeddings: embeddings of desirable prompts (not read by this method)
        rejected_embeddings:
            embeddings of undesirable prompts; used to weight the rejected losses when underlying-distribution
            matching is enabled
        do_train: If True, the running reward statistics are updated with this batch's rewards.

    Returns:
        A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta). The losses tensor contains the
        BCO loss for each example in the batch (chosen examples first, then rejected ones). The chosen_rewards and
        rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. The
        delta value contains the moving average of all implicit rewards.
    """
    # Implicit rewards: beta-scaled log-ratio between policy and reference.
    chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)

    if do_train:
        batch_rewards = torch.cat((chosen_rewards, rejected_rewards), 0)
        self.running.update(batch_rewards.detach())
    delta = torch.as_tensor(self.running.mean, device=chosen_rewards.device)

    # Chosen rewards are pushed above delta, rejected rewards below it.
    chosen_losses = -F.logsigmoid(chosen_rewards - delta)
    rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))

    if self.match_underlying_distribution:
        chosen_weight = torch.ones_like(chosen_losses)
        rejected_weight = self._get_udm_weight(rejected_embeddings)
        loss_parts = (chosen_weight * chosen_losses, rejected_weight * rejected_losses)
    else:
        loss_parts = (chosen_losses, rejected_losses)

    losses = torch.cat(loss_parts, dim=0)
    return losses, chosen_rewards, rejected_rewards, delta
1276
+
1277
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    do_train: bool = True,
):
    """Compute the BCO loss and other metrics for the given batch of inputs for train or test.

    Precomputed `reference_logps` are split into chosen and rejected by the truthiness of the matching label,
    which matches how the datasets are split elsewhere in this trainer. This way labels that are truthy but
    are not the `True` singleton (e.g. `1` or tensor elements) are no longer silently dropped.

    Args:
        model: Policy model passed to `self.forward`.
        batch: Collated batch; tensor values are moved to `self.accelerator.device`. If it contains
            `reference_logps`, those are used instead of running the reference model.
        do_train: Forwarded to `self.bco_loss`.

    Returns:
        `(loss, metrics)`: the NaN-ignoring mean of the per-example losses (plus the weighted auxiliary loss
        when enabled) and a dict of metrics gathered across processes.
    """
    metrics = {}
    batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

    forward_output = self.forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = forward_output[:4]
    if self.aux_loss_enabled:
        aux_loss = forward_output[4]

    # if reference_logps in batch use them, otherwise use the reference model
    if "reference_logps" in batch:
        reference_logps = batch["reference_logps"]
        labels = batch["label"]
        # Partition by truthiness so that every example lands in exactly one of the two groups.
        chosen_idx, rejected_idx = [], []
        for i in range(reference_logps.shape[0]):
            (chosen_idx if bool(labels[i]) else rejected_idx).append(i)

        reference_chosen_logps = reference_logps[chosen_idx, ...]
        reference_rejected_logps = reference_logps[rejected_idx, ...]
    else:
        with torch.no_grad():
            if self.ref_model is None:
                # No separate reference model: reuse the policy inside the null-reference context.
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.forward(self.model, batch)[:4]
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.forward(self.ref_model, batch)[:4]

    chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)

    losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        chosen_embeddings,
        rejected_embeddings,
        do_train=do_train,
    )
    metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()

    num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
    num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

    all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
    all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

    # Sums and counts are stored separately; any averaging happens outside this method.
    if all_num_chosen > 0:
        metrics["rewards/chosen_sum"] = (
            self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
        )
        metrics["logps/chosen_sum"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
        )
        metrics["logits/chosen_sum"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
        )
        metrics["count/chosen"] = all_num_chosen

    if all_num_rejected > 0:
        metrics["rewards/rejected_sum"] = (
            self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
        )
        metrics["logps/rejected_sum"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
        )
        metrics["logits/rejected_sum"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
        )
        metrics["count/rejected"] = all_num_rejected

    loss = losses.nanmean()
    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
1370
+
1371
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for one batch and record its metrics.

    Args:
        model: Policy model handed to `self.get_batch_loss_metrics`.
        inputs: Collated batch.
        return_outputs: If True, return `(loss, metrics)` instead of just the loss.
        num_items_in_batch: Accepted for compatibility with the caller; not used here.

    Returns:
        The loss moved to `self.args.device`, optionally paired with the metrics dict.
    """
    # Autocast is only needed when the PEFT model has been cast to bf16.
    if self._peft_has_been_casted_to_bf16:
        amp_ctx = autocast(self.accelerator.device.type)
    else:
        amp_ctx = nullcontext()

    with amp_ctx:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
    loss = loss.to(self.args.device)

    # force log the metrics
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
1394
+
1395
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to its history in `self._stored_metrics[train_eval]`.

    Args:
        metrics: Mapping from metric name to the value to record.
        train_eval: Which history (`"train"` or `"eval"`) receives the values.
    """
    for name in metrics:
        self._stored_metrics[train_eval][name].append(metrics[name])
1398
+
1399
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
    """Return a `SequentialSampler` over `dataset` (default: `self.train_dataset`).

    Returns `None` when there is no dataset or `has_length` reports it has no length.
    """
    target = self.train_dataset if dataset is None else dataset
    if target is not None and has_length(target):
        return SequentialSampler(target)
    return None
1405
+
1406
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
    """Generate samples from the model and reference model for the given batch of inputs.

    Returns the decoded policy outputs and the decoded reference outputs, both padded to
    `self.max_length` before decoding.
    """
    pad_id = self.processing_class.pad_token_id

    # The policy and reference generations all use exactly the same arguments.
    generation_kwargs = {
        "input_ids": batch["prompt_input_ids"],
        "attention_mask": batch["prompt_attention_mask"],
        "max_length": self.max_length,
        "do_sample": True,
        "pad_token_id": pad_id,
    }

    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch amp context manager as some hidden states are silently casted to full precision.
    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with amp_context:
        policy_output = model.generate(**generation_kwargs)

        # Prefer a precomputed reference output; otherwise generate one.
        if "reference_output" in batch:
            reference_output = batch["reference_output"]
        elif self.ref_model is None:
            # No separate reference model: generate from `self.model` inside `null_ref_context`.
            with self.null_ref_context():
                reference_output = self.model.generate(**generation_kwargs)
        else:
            reference_output = self.ref_model.generate(**generation_kwargs)

    decoded = []
    for sequences in (policy_output, reference_output):
        padded = pad_to_length(sequences, self.max_length, pad_id)
        decoded.append(self.processing_class.batch_decode(padded, skip_special_tokens=True))

    policy_output_decoded, reference_output_decoded = decoded
    return policy_output_decoded, reference_output_decoded
1452
+
1453
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one evaluation step without gradients and record its metrics.

    Returns `(loss, None, None)` when `prediction_loss_only` is true. Otherwise returns
    `(loss, logits, labels)`, where `logits` holds the summed chosen/rejected logit metrics
    that are present and not ignored, and `labels` is a zero tensor of matching length.
    """
    if ignore_keys is None:
        ignore_keys = (
            getattr(model.config, "keys_to_ignore_at_inference", []) if hasattr(model, "config") else []
        )

    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with torch.no_grad(), amp_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, do_train=False)

    # Force-log the metrics (main process only).
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # Logits for the chosen and rejected samples from the model, keyed by their eval names.
    logits_dict = {}
    for metric_key, eval_key in (
        ("logits/chosen_sum", "eval_logits/chosen"),
        ("logits/rejected_sum", "eval_logits/rejected"),
    ):
        if metric_key in metrics:
            logits_dict[eval_key] = metrics[metric_key]

    kept = [value for key, value in logits_dict.items() if key not in ignore_keys]
    device = self.accelerator.device
    logits = torch.tensor(kept, device=device)
    labels = torch.zeros(logits.shape[0], device=device)

    return (loss.detach(), logits, labels)
1490
+
1491
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
    `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, one random batch is drawn from the evaluation dataset, the
    examples whose `label` is false are used as prompts, and the policy/reference completions are logged
    as a table to W&B and/or Comet (depending on `self.args.report_to`) before the base loop runs.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        num_samples = len(dataloader.dataset)
        # `random.sample` raises ValueError when asked for more items than exist, so clamp the
        # sample size for evaluation sets smaller than one eval batch.
        sample_size = min(self.args.eval_batch_size, num_samples)

        if sample_size > 0:
            random_indices = random.sample(range(num_samples), k=sample_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device)
            target_indices = torch.where(~target_labels)[0]

            # Nothing to generate when the sampled batch has no false-labelled examples.
            if target_indices.numel() > 0:
                # Index with plain ints so the result is always a list of prompts. The previous
                # `itemgetter(*indices)` returned a bare string for a single index, which made the
                # `zip` below iterate over characters, and raised for zero indices.
                prompts = [random_batch["prompt"][i] for i in target_indices.tolist()]
                target_batch = {
                    "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
                    "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
                    "prompt": prompts,
                }
                policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(
                    self.model, target_batch
                )

                # Strip the prompt prefix so only the completions are shown.
                table = pd.DataFrame(
                    columns=["Prompt", "Policy", "Ref Model"],
                    data=[
                        [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                        for prompt, pol, ref in zip(prompts, policy_output_decoded, ref_output_decoded)
                    ],
                )
                if "wandb" in self.args.report_to:
                    wandb.log({"game_log": wandb.Table(data=table)})

                if "comet_ml" in self.args.report_to:
                    log_table_to_comet_experiment(
                        name="game_log.csv",
                        table=table,
                    )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
1548
+
1549
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    Per-split sums stored for rewards, logps and logits are turned into averages using the stored
    counts, a reward margin is derived when both splits are present, and every remaining stored
    metric is averaged. The stored metrics for the split are then cleared.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # `logs` carries either 'loss' (training) or 'eval_loss' (evaluation).
    train_eval = "train" if "loss" in logs else "eval"
    # Train metrics get no prefix; eval metrics get 'eval_'.
    prefix = "" if train_eval == "train" else "eval_"
    stored = self._stored_metrics[train_eval]

    # Turn the accumulated sums into averages, split by split.
    for split in ("chosen", "rejected"):
        count_key = f"count/{split}"
        if count_key not in stored:
            continue
        count_sum = torch.Tensor(stored[count_key]).sum().item()
        for metric in ("rewards", "logps", "logits"):
            sum_key = f"{metric}/{split}_sum"
            total = torch.Tensor(stored[sum_key]).sum().item()
            logs[f"{prefix}{metric}/{split}"] = total / count_sum
            # The raw sum is obsolete once averaged.
            del stored[sum_key]
        del stored[count_key]

    # Reward margin between the two splits.
    chosen_key = f"{prefix}rewards/chosen"
    rejected_key = f"{prefix}rewards/rejected"
    if chosen_key in logs and rejected_key in logs:
        logs[f"{prefix}rewards/margins"] = logs[chosen_key] - logs[rejected_key]

    # Average whatever else was stored for this split.
    for key, values in stored.items():
        logs[f"{prefix}{key}"] = torch.Tensor(values).mean().item()

    del self._stored_metrics[train_eval]
    return super().log(logs, start_time)
1583
+
1584
+ # Ensure the model card is saved along with the checkpoint
1585
def _save_checkpoint(self, model, trial):
    """Write the model card, then delegate to the parent checkpoint saver.

    The model name is the last path component of `args.hub_model_id` when it is set, otherwise
    the name of the `args.output_dir` directory.
    """
    hub_id = self.args.hub_model_id
    model_name = hub_id.split("/")[-1] if hub_id is not None else Path(self.args.output_dir).name
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
1592
+
1593
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Does nothing unless `is_world_process_zero()` is true. The card is saved as `README.md`
    in `args.output_dir`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    config = self.model.config

    # Only report a base model when `_name_or_path` is not a local directory.
    base_model = None
    if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
        base_model = config._name_or_path

    # Normalize `tags` to a mutable set.
    if tags is None:
        tag_set = set()
    elif isinstance(tags, str):
        tag_set = {tags}
    else:
        tag_set = set(tags)

    if hasattr(config, "unsloth_version"):
        tag_set.add("unsloth")

    tag_set.update(self._tag_names)

    citation = textwrap.dedent("""\
        @article{jung2024binary,
        title = {{Binary Classifier Optimization for Large Language Model Alignment}},
        author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
        year = 2024,
        eprint = {arXiv:2404.04656}
        }""")

    # Link the active W&B run only when wandb is installed and a run exists.
    wandb_url = wandb.run.url if is_wandb_available() and wandb.run is not None else None

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tag_set,
        wandb_url=wandb_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="BCO",
        trainer_citation=citation,
        paper_title="Binary Classifier Optimization for Large Language Model Alignment",
        paper_id="2404.04656",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
1654
class UnslothBCOTrainer(_UnslothBCOTrainer):
    """

    Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.

    Before delegating to the parent trainer, the constructor reconciles the mixed-precision flags on
    `args` with the model dtype, adjusts evaluation settings, forces right padding on the processing
    class and swaps incompatible data collators.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation
            and loss. If no reference model is provided, the trainer will create a reference model with the same
            architecture as the model to be optimized.
        args (`BCOConfig`):
            The arguments to use for training.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
            The data collator to use for training. If None is specified, the default data collator
            (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
            a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.
        model_adapter_name (`str`, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str`, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.

    """
    def __init__(
        self,
        model = None,
        ref_model = None,
        args = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        data_collator = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        model_adapter_name = None,
        ref_adapter_name = None,
        embedding_func = None,
        embedding_tokenizer = None,
        **kwargs
    ):
        if args is None: args = UnslothBCOConfig()

        # --- Mixed precision: reconcile the requested flags with the model dtype ---
        # Non-bool flag values are treated as "not requested".
        use_bf16 = getattr(args, 'bf16', False)
        if not isinstance(use_bf16, bool): use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if not isinstance(use_fp16, bool): use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # NOTE(review): `model` defaults to None but is dereferenced here, so a model
        # instance is effectively required — confirm whether `model_init`-only use is intended.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing requested explicitly: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # --- Evaluation defaults ---
        # Fix: the evaluation dataset is a constructor argument, not an attribute of `args`;
        # looking it up on `args` meant this branch could never run.
        if eval_dataset is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps

        # --- Full-eval precision: keep it consistent with the training precision ---
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if not isinstance(fp16_full_eval, bool): fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if not isinstance(bf16_full_eval, bool): bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Signal via the environment that logits are needed when metric callbacks are supplied.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the model's max sequence length when the config has the field but left it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model.max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # --- Tokenizer / collator setup ---
        # Force right padding on the processing class and on any tokenizer it wraps.
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        tokenizer_like = processing_class
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator when it does not match the presence of a 'labels' column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(tokenizer_like, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(tokenizer_like)
        else:
            # With the vision collator: keep all columns and skip dataset preparation.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` but wraps a tokenizer: collate with that tokenizer.
            if not hasattr(tokenizer_like, 'pad') and hasattr(tokenizer_like, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(tokenizer_like.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(tokenizer_like.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('bco_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            args = args,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            data_collator = data_collator,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            model_adapter_name = model_adapter_name,
            ref_adapter_name = ref_adapter_name,
            embedding_func = embedding_func,
            embedding_tokenizer = embedding_tokenizer,**kwargs)

        # Remove any NEFTune hook left on the trainer and set the noise alpha on the embeddings instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        # Attach the accelerator's scaler to every level of the nested `.model` chain.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
1848
+
1849
+
1850
# Suppress log records that mention `use_cache=True`.
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that drops every record whose message contains a given substring."""

        def __init__(self, text):
            self.text = text

        def filter(self, x):
            # Keep the record only when the substring is absent from its message.
            return self.text not in x.getMessage()

    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
1857
+
unsloth_compiled_cache/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options forwarded as `options=` to the `torch.compile` decorators in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """Return the log-softmax of `logits` gathered at `index`, computed in four chunks.

    The leading dimensions are flattened, the rows are split into (at most) four chunks, and each
    chunk is upcast to float32 before `logit[index] - logsumexp(logits)` is evaluated. The result is
    reshaped to `(logits.shape[0], logits.shape[1])`.
    """
    # Flatten so that each row is one position's distribution over the last dimension.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)

    # Split into 4 chunks only.
    logit_chunks = torch.chunk(flat_logits, chunks = 4, dim = 0)
    index_chunks = torch.chunk(flat_index, chunks = 4, dim = 0)

    pieces = []
    # Same computation as selective_log_softmax(chunk_logits, chunk_index), one chunk at a time.
    for logit_part, index_part in zip(logit_chunks, index_chunks):
        logit_part = logit_part.to(torch.float32)
        picked = torch.gather(logit_part, dim = -1, index = index_part.unsqueeze(-1)).squeeze(-1)
        normalizer = torch.logsumexp(logit_part, dim = -1)
        pieces.append(picked - normalizer)

    merged = torch.concat(pieces)
    return merged.reshape((logits.shape[0], logits.shape[1]))
51
+ @dataclass
52
+ class UnslothCPOConfig(CPOConfig):
53
+ """
54
+
55
+ Configuration class for the [`CPOTrainer`].
56
+
57
+ This class includes only the parameters that are specific to CPO training. For a full list of training arguments,
58
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
59
+ differ from those in [`~transformers.TrainingArguments`].
60
+
61
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
62
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
63
+ command line.
64
+
65
+ Parameters:
66
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
67
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
68
+ to use the default data collator.
69
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
70
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
71
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
72
+ Maximum length of the completion. This argument is required if you want to use the default data collator
73
+ and your model is an encoder-decoder.
74
+ beta (`float`, *optional*, defaults to `0.1`):
75
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
76
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
77
+ the [paper](https://huggingface.co/papers/2310.12036).
78
+ label_smoothing (`float`, *optional*, defaults to `0.0`):
79
+ Label smoothing factor. This argument is required if you want to use the default data collator.
80
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
81
+ Type of loss to use. Possible values are:
82
+
83
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
84
+ - `"hinge"`: hinge loss on the normalized likelihood from the
85
+ [SLiC](https://huggingface.co/papers/2305.10425) paper.
86
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
87
+ - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
88
+
89
+ disable_dropout (`bool`, *optional*, defaults to `True`):
90
+ Whether to disable dropout in the model.
91
+ cpo_alpha (`float`, *optional*, defaults to `1.0`):
92
+ Weight of the BC regularizer in CPO training.
93
+ simpo_gamma (`float`, *optional*, defaults to `0.5`):
94
+ Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
95
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
96
+ Label pad token id. This argument is required if you want to use the default data collator.
97
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
98
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
99
+ truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
100
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
101
+ This argument is required if you want to use the default data collator.
102
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
103
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
104
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
105
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
106
+ you need to specify if the model returned by the callable is an encoder-decoder model.
107
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
108
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
109
+ string.
110
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
111
+ Number of processes to use for processing the dataset.
112
+
113
+ """
114
+ vllm_sampling_params: Optional[Any] = field(
115
+ default = None,
116
+ metadata = {'help': 'vLLM SamplingParams'},
117
+ )
118
+ unsloth_num_chunks : Optional[int] = field(
119
+ default = -1,
120
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
121
+ )
122
+ max_seq_length : Optional[int] = field(
123
+ default = None,
124
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
125
+ )
126
def __init__(
    self,
    output_dir = None,
    overwrite_output_dir = None,
    do_train = False,
    do_eval = False,
    do_predict = False,
    eval_strategy = 'no',
    prediction_loss_only = False,
    per_device_train_batch_size = 4,
    per_device_eval_batch_size = 4,
    per_gpu_train_batch_size = None,
    per_gpu_eval_batch_size = None,
    gradient_accumulation_steps = 2,
    eval_accumulation_steps = 2,
    eval_delay = 0,
    torch_empty_cache_steps = 250,
    learning_rate = 5e-05,
    weight_decay = 0.01,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-08,
    max_grad_norm = 1.0,
    num_train_epochs = 3.0,
    max_steps = -1,
    lr_scheduler_type = 'linear',
    warmup_ratio = 0.1,
    warmup_steps = 0,
    log_level = 'passive',
    log_level_replica = 'warning',
    log_on_each_node = True,
    logging_dir = None,
    logging_strategy = 'steps',
    logging_first_step = False,
    logging_steps = 1,
    logging_nan_inf_filter = False,
    save_strategy = 'steps',
    save_steps = 500,
    save_total_limit = None,
    save_safetensors = True,
    save_on_each_node = False,
    save_only_model = False,
    restore_callback_states_from_checkpoint = False,
    no_cuda = False,
    use_cpu = False,
    use_mps_device = False,
    seed = 3407,
    data_seed = 3407,
    jit_mode_eval = False,
    use_ipex = False,
    bf16 = False,
    fp16 = False,
    fp16_opt_level = 'O1',
    half_precision_backend = 'auto',
    bf16_full_eval = False,
    fp16_full_eval = False,
    tf32 = None,
    local_rank = -1,
    ddp_backend = None,
    tpu_num_cores = None,
    tpu_metrics_debug = False,
    debug = '',
    dataloader_drop_last = False,
    eval_steps = None,
    dataloader_num_workers = 0,
    dataloader_prefetch_factor = None,
    past_index = -1,
    run_name = None,
    disable_tqdm = None,
    remove_unused_columns = True,
    label_names = None,
    load_best_model_at_end = False,
    metric_for_best_model = None,
    greater_is_better = None,
    ignore_data_skip = False,
    fsdp = '',
    fsdp_min_num_params = 0,
    fsdp_config = None,
    fsdp_transformer_layer_cls_to_wrap = None,
    accelerator_config = None,
    deepspeed = None,
    label_smoothing_factor = 0.0,
    optim = 'adamw_8bit',
    optim_args = None,
    adafactor = False,
    group_by_length = False,
    length_column_name = 'length',
    report_to = None,
    ddp_find_unused_parameters = None,
    ddp_bucket_cap_mb = None,
    ddp_broadcast_buffers = None,
    dataloader_pin_memory = True,
    dataloader_persistent_workers = False,
    skip_memory_metrics = True,
    use_legacy_prediction_loop = False,
    push_to_hub = False,
    resume_from_checkpoint = None,
    hub_model_id = None,
    hub_strategy = 'every_save',
    hub_token = None,
    hub_private_repo = None,
    hub_always_push = False,
    hub_revision = None,
    gradient_checkpointing = False,
    gradient_checkpointing_kwargs = None,
    include_inputs_for_metrics = False,
    eval_do_concat_batches = True,
    fp16_backend = 'auto',
    push_to_hub_model_id = None,
    push_to_hub_organization = None,
    push_to_hub_token = None,
    mp_parameters = '',
    auto_find_batch_size = True,
    full_determinism = False,
    torchdynamo = None,
    ray_scope = 'last',
    ddp_timeout = 1800,
    torch_compile = False,
    torch_compile_backend = None,
    torch_compile_mode = None,
    include_tokens_per_second = False,
    include_num_input_tokens_seen = False,
    neftune_noise_alpha = None,
    optim_target_modules = None,
    batch_eval_metrics = False,
    eval_on_start = False,
    use_liger_kernel = False,
    liger_kernel_config = None,
    eval_use_gather_object = False,
    average_tokens_across_devices = True,
    max_length = 1024,
    max_prompt_length = 512,
    max_completion_length = None,
    beta = 0.1,
    label_smoothing = 0.0,
    loss_type = 'sigmoid',
    disable_dropout = True,
    cpo_alpha = 1.0,
    simpo_gamma = 0.5,
    label_pad_token_id = -100,
    padding_value = None,
    truncation_mode = 'keep_end',
    generate_during_eval = False,
    is_encoder_decoder = None,
    model_init_kwargs = None,
    dataset_num_proc = None,
    vllm_sampling_params = None,
    unsloth_num_chunks = -1,
    max_seq_length = None,
    **kwargs,
):
    """Build the CPO configuration, applying Unsloth's defaults and sanity checks.

    Every argument except ``vllm_sampling_params``, ``unsloth_num_chunks`` and
    ``max_seq_length`` is forwarded unchanged (together with ``**kwargs``) to the
    parent config's ``__init__``; those three are stored on the instance afterwards.

    Adjustments made before forwarding:
        * ``learning_rate`` is range-checked.
        * If ``output_dir`` is ``None`` while ``save_strategy`` / ``save_steps`` are
          still at their defaults (``'steps'`` / ``500``), ``output_dir`` becomes
          ``'unsloth_training_checkpoints'`` and ``save_strategy`` becomes ``'no'``.
        * If ``dataset_num_proc`` is ``None`` it is set to
          ``min(cpu_count() * 2, 2)``.

    Raises:
        FloatingPointError: If ``learning_rate`` is smaller than ``1e-7``.
        OverflowError: If ``learning_rate`` is larger than ``1``.
    """
    if learning_rate < 1e-7:
        raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
    if learning_rate > 1:
        raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

    # The caller left all checkpointing options at their defaults: pick a
    # default directory name and turn periodic checkpoint saving off.
    if output_dir is None and save_strategy == 'steps' and save_steps == 500:
        output_dir = 'unsloth_training_checkpoints'
        save_strategy = 'no'

    if dataset_num_proc is None:
        from multiprocessing import cpu_count
        # NOTE: cpu_count() is at least 1, so this expression always yields 2.
        dataset_num_proc = min(cpu_count()*2, 2)

    super().__init__(
        output_dir = output_dir,
        overwrite_output_dir = overwrite_output_dir,
        do_train = do_train,
        do_eval = do_eval,
        do_predict = do_predict,
        eval_strategy = eval_strategy,
        prediction_loss_only = prediction_loss_only,
        per_device_train_batch_size = per_device_train_batch_size,
        per_device_eval_batch_size = per_device_eval_batch_size,
        per_gpu_train_batch_size = per_gpu_train_batch_size,
        per_gpu_eval_batch_size = per_gpu_eval_batch_size,
        gradient_accumulation_steps = gradient_accumulation_steps,
        eval_accumulation_steps = eval_accumulation_steps,
        eval_delay = eval_delay,
        torch_empty_cache_steps = torch_empty_cache_steps,
        learning_rate = learning_rate,
        weight_decay = weight_decay,
        adam_beta1 = adam_beta1,
        adam_beta2 = adam_beta2,
        adam_epsilon = adam_epsilon,
        max_grad_norm = max_grad_norm,
        num_train_epochs = num_train_epochs,
        max_steps = max_steps,
        lr_scheduler_type = lr_scheduler_type,
        warmup_ratio = warmup_ratio,
        warmup_steps = warmup_steps,
        log_level = log_level,
        log_level_replica = log_level_replica,
        log_on_each_node = log_on_each_node,
        logging_dir = logging_dir,
        logging_strategy = logging_strategy,
        logging_first_step = logging_first_step,
        logging_steps = logging_steps,
        logging_nan_inf_filter = logging_nan_inf_filter,
        save_strategy = save_strategy,
        save_steps = save_steps,
        save_total_limit = save_total_limit,
        save_safetensors = save_safetensors,
        save_on_each_node = save_on_each_node,
        save_only_model = save_only_model,
        restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
        no_cuda = no_cuda,
        use_cpu = use_cpu,
        use_mps_device = use_mps_device,
        seed = seed,
        data_seed = data_seed,
        jit_mode_eval = jit_mode_eval,
        use_ipex = use_ipex,
        bf16 = bf16,
        fp16 = fp16,
        fp16_opt_level = fp16_opt_level,
        half_precision_backend = half_precision_backend,
        bf16_full_eval = bf16_full_eval,
        fp16_full_eval = fp16_full_eval,
        tf32 = tf32,
        local_rank = local_rank,
        ddp_backend = ddp_backend,
        tpu_num_cores = tpu_num_cores,
        tpu_metrics_debug = tpu_metrics_debug,
        debug = debug,
        dataloader_drop_last = dataloader_drop_last,
        eval_steps = eval_steps,
        dataloader_num_workers = dataloader_num_workers,
        dataloader_prefetch_factor = dataloader_prefetch_factor,
        past_index = past_index,
        run_name = run_name,
        disable_tqdm = disable_tqdm,
        remove_unused_columns = remove_unused_columns,
        label_names = label_names,
        load_best_model_at_end = load_best_model_at_end,
        metric_for_best_model = metric_for_best_model,
        greater_is_better = greater_is_better,
        ignore_data_skip = ignore_data_skip,
        fsdp = fsdp,
        fsdp_min_num_params = fsdp_min_num_params,
        fsdp_config = fsdp_config,
        fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
        accelerator_config = accelerator_config,
        deepspeed = deepspeed,
        label_smoothing_factor = label_smoothing_factor,
        optim = optim,
        optim_args = optim_args,
        adafactor = adafactor,
        group_by_length = group_by_length,
        length_column_name = length_column_name,
        report_to = report_to,
        ddp_find_unused_parameters = ddp_find_unused_parameters,
        ddp_bucket_cap_mb = ddp_bucket_cap_mb,
        ddp_broadcast_buffers = ddp_broadcast_buffers,
        dataloader_pin_memory = dataloader_pin_memory,
        dataloader_persistent_workers = dataloader_persistent_workers,
        skip_memory_metrics = skip_memory_metrics,
        use_legacy_prediction_loop = use_legacy_prediction_loop,
        push_to_hub = push_to_hub,
        resume_from_checkpoint = resume_from_checkpoint,
        hub_model_id = hub_model_id,
        hub_strategy = hub_strategy,
        hub_token = hub_token,
        hub_private_repo = hub_private_repo,
        hub_always_push = hub_always_push,
        hub_revision = hub_revision,
        gradient_checkpointing = gradient_checkpointing,
        gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
        include_inputs_for_metrics = include_inputs_for_metrics,
        eval_do_concat_batches = eval_do_concat_batches,
        fp16_backend = fp16_backend,
        push_to_hub_model_id = push_to_hub_model_id,
        push_to_hub_organization = push_to_hub_organization,
        push_to_hub_token = push_to_hub_token,
        mp_parameters = mp_parameters,
        auto_find_batch_size = auto_find_batch_size,
        full_determinism = full_determinism,
        torchdynamo = torchdynamo,
        ray_scope = ray_scope,
        ddp_timeout = ddp_timeout,
        torch_compile = torch_compile,
        torch_compile_backend = torch_compile_backend,
        torch_compile_mode = torch_compile_mode,
        include_tokens_per_second = include_tokens_per_second,
        include_num_input_tokens_seen = include_num_input_tokens_seen,
        neftune_noise_alpha = neftune_noise_alpha,
        optim_target_modules = optim_target_modules,
        batch_eval_metrics = batch_eval_metrics,
        eval_on_start = eval_on_start,
        use_liger_kernel = use_liger_kernel,
        liger_kernel_config = liger_kernel_config,
        eval_use_gather_object = eval_use_gather_object,
        average_tokens_across_devices = average_tokens_across_devices,
        max_length = max_length,
        max_prompt_length = max_prompt_length,
        max_completion_length = max_completion_length,
        beta = beta,
        label_smoothing = label_smoothing,
        loss_type = loss_type,
        disable_dropout = disable_dropout,
        cpo_alpha = cpo_alpha,
        simpo_gamma = simpo_gamma,
        label_pad_token_id = label_pad_token_id,
        padding_value = padding_value,
        truncation_mode = truncation_mode,
        generate_during_eval = generate_during_eval,
        is_encoder_decoder = is_encoder_decoder,
        model_init_kwargs = model_init_kwargs,
        dataset_num_proc = dataset_num_proc,
        **kwargs,
    )
    # Unsloth-specific options are not forwarded to the parent constructor;
    # they are stored on the instance once it has run.
    self.vllm_sampling_params = vllm_sampling_params
    self.unsloth_num_chunks = unsloth_num_chunks
    self.max_seq_length = max_seq_length
435
+
436
+ class _UnslothCPOTrainer(Trainer):
437
+ r""""""
438
+
439
+ _tag_names = ["trl", "cpo"]
440
+
441
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: Optional[CPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
):
    """Set up the CPO trainer: load/prepare the model, preprocess the datasets, init the parent.

    Args:
        model: A model instance, or a string passed to
            `AutoModelForCausalLM.from_pretrained` together with `args.model_init_kwargs`.
        args: The CPO configuration. Required in practice, despite the `None` default.
        data_collator: Collator to use; when `None`, a `DPODataCollatorWithPadding` is
            created and `args.remove_unused_columns` is forced to `False`.
        train_dataset: Training dataset; it is mapped through prompt extraction,
            chat templating and `tokenize_row`.
        eval_dataset: Optional evaluation dataset, preprocessed the same way.
        processing_class: Tokenizer/processor used for tokenization and padding. Required.
        model_init, callbacks, optimizers, preprocess_logits_for_metrics, compute_metrics:
            Forwarded unchanged to the parent `Trainer`.
        peft_config: When given (and PEFT is installed) the model is prepared for
            k-bit / gradient-checkpointed training.

    Raises:
        ValueError: If `args` is `None`; if `model_init_kwargs` is given for an already
            instantiated model; on an invalid `torch_dtype`; if `peft_config` is passed
            without PEFT installed; if `generate_during_eval` is requested without W&B
            or Comet; if neither a model nor `args.is_encoder_decoder` is available;
            if `processing_class` is `None`; if `max_prompt_length >= max_length`;
            or if `loss_type == "kto_pair"`.
        AttributeError: If the parent `Trainer` did not create an `accelerator`.
    """
    if args is None:
        # `args` is dereferenced right away; fail with a clear message instead of an
        # AttributeError on `None`.
        raise ValueError("`args` must be a CPOConfig instance, but got `None`.")

    def _enable_input_require_grads(target_model):
        # Make the embedding outputs require grad (needed with gradient checkpointing).
        # For backward compatibility with older versions of transformers
        if hasattr(target_model, "enable_input_require_grads"):
            target_model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            target_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            model_init_kwargs["torch_dtype"] = torch_dtype

    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            _support_gc_kwargs = hasattr(
                args, "gradient_checkpointing_kwargs"
            ) and "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif args.gradient_checkpointing:
            _enable_input_require_grads(model)

        # The model is used as-is here; this block does not wrap it with `peft_config`.
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    # For models that use gradient_checkpointing, we need to attach a hook that enables input
    # to explicitly have `requires_grad=True`, otherwise training will either silently
    # fail or completely fail.
    elif args.gradient_checkpointing:
        _enable_input_require_grads(model)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    if self.is_encoder_decoder:
        self.decoder_start_token_id = model.config.decoder_start_token_id
        self.pad_token_id = model.config.pad_token_id

    if processing_class is None:
        raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
    if args.max_length is None:
        warnings.warn(
            "`max_length` is not set in the CPOConfig's init"
            " it will default to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    else:
        max_length = args.max_length
    if args.max_prompt_length is None:
        warnings.warn(
            "`max_prompt_length` is not set in the CPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    else:
        max_prompt_length = args.max_prompt_length

    if not max_prompt_length < max_length:
        raise ValueError(
            f"max_prompt_length ({max_prompt_length}) should be strictly less than max_length ({max_length})."
        )

    if args.max_completion_length is None and self.is_encoder_decoder:
        warnings.warn(
            "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_completion_length = 128
    else:
        max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(model)

    # These attributes are read by `tokenize_row`, which runs in the dataset `.map`
    # calls below, i.e. before the parent constructor is invoked.
    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.max_completion_length = max_completion_length
    self.processing_class = processing_class

    if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
        warnings.warn(
            f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
            "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
            UserWarning,
        )
    if args.loss_type == "kto_pair":
        raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")

    self.beta = args.beta
    self.label_smoothing = args.label_smoothing
    self.loss_type = args.loss_type
    self.cpo_alpha = args.cpo_alpha
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    if args.loss_type == "simpo":
        self.simpo_gamma = args.simpo_gamma

    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
    # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
    # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().main_process_first():
        # Extract the prompt if needed, and apply the chat template if needed
        train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
        train_dataset = train_dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
        )
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
            )

        # tokenize the dataset
        train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )
708
+
709
def build_tokenized_answer(self, prompt, answer):
    """
    Tokenize `prompt + answer` in one pass and split the result into prompt and answer parts.

    Llama tokenizer does NOT satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
    b)[len(enc(a)):]`. Reference:
    https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Args:
        prompt: Prompt text.
        answer: Answer text, appended directly to `prompt` before tokenizing.

    Returns:
        A dict with `prompt_input_ids` / `prompt_attention_mask` (the prompt part of the joint
        tokenization) and `input_ids` / `attention_mask` (the answer part).

    Raises:
        ValueError: If the prompt alone tokenizes to more tokens than `prompt + answer`, or if the
            prompt ids and prompt attention mask end up with different lengths.
    """
    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]

    # `enc(a) + enc(a + b)[len(enc(a)):]` has the same length as `enc(a + b)` unless the
    # prompt alone is longer than the joint tokenization; only that case can fail.
    if len(prompt_input_ids) > len(full_input_ids):
        raise ValueError(
            "The tokenized prompt is longer than the tokenized prompt + answer, so the answer tokens cannot be "
            "separated from the prompt tokens."
        )

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If the prompt tokenized on its own differs from the matching prefix of prompt+answer,
    # the last prompt token was merged; hand that token to the answer side.
    if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
    prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=full_input_ids[response_token_ids_start_idx:],
        attention_mask=full_attention_mask[response_token_ids_start_idx:],
    )
757
+
758
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize a single row from a CPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
    chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
    we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
    of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.

    Args:
        feature: Row providing the `"prompt"`, `"chosen"` and `"rejected"` entries.
        model: Optional model; only used in the encoder-decoder branch, and only when it exposes
            `prepare_decoder_input_ids_from_labels`.

    Returns:
        The tokenized row. In the decoder-only branch the keys are prefixed with `chosen_`, `rejected_`
        or nothing (prompt entries, already named `prompt_*`); `token_type_ids` entries are dropped.

    Raises:
        ValueError: In the decoder-only branch, if prompt/chosen/rejected is not a `str`, if the two prompt
            tokenizations differ by more than the last token, or if `truncation_mode` is unknown
            (raised only when truncation is actually needed).
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]

    if not self.is_encoder_decoder:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        prompt_tokens = {k: v[:prompt_len_input_ids] for k, v in prompt_tokens.items()}

        # Make sure the two prompts differ in at most one token
        # and that their lengths differ by at most 1
        num_diff_tokens = sum(
            a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        # Measured once, before any truncation; both truncation passes below use this value.
        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels: prompt + answer, with the prompt positions replaced by label_pad_token_id
        prefixed_tokens = {}
        for prefix, answer_tokens in (("chosen_", chosen_tokens), ("rejected_", rejected_tokens)):
            sequence_tokens = {
                k: answer_tokens[f"prompt_{k}"] + answer_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            num_prompt_tokens = len(answer_tokens["prompt_input_ids"])
            sequence_tokens["labels"] = sequence_tokens["input_ids"][:]
            sequence_tokens["labels"][:num_prompt_tokens] = [self.label_pad_token_id] * num_prompt_tokens
            prefixed_tokens[prefix] = sequence_tokens
        prefixed_tokens[""] = prompt_tokens

        for prefix, toks in prefixed_tokens.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{prefix}{type_key}"] = tokens

    else:
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )

    return batch
902
+
903
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch axis.

    Every tensor entry whose key starts with ``chosen`` is padded to a common
    length and stored under the same key with ``chosen`` replaced by
    ``concatenated``. The matching ``rejected`` entry is padded the same way
    and appended below it (dim 0), then moved to ``device``.

    Args:
        batch:
            Batch dictionary. Reads 'chosen_input_ids' / 'rejected_input_ids'
            (or the '*_labels' pair for encoder-decoder models) to find the
            padded length.
        is_encoder_decoder:
            Whether the model is an encoder-decoder model.
        label_pad_token_id:
            Padding value used for label tensors (and for every tensor in the
            encoder-decoder case).
        padding_value:
            Padding value used for '*_input_ids' tensors.
        device:
            Device the concatenated tensors are moved to.

    Returns:
        Dictionary of concatenated tensors, e.g. 'concatenated_input_ids'.
    """
    # The common padded length comes from the labels for encoder-decoder
    # models and from the input ids otherwise.
    length_key = "labels" if is_encoder_decoder else "input_ids"
    max_length = max(batch[f"chosen_{length_key}"].shape[1], batch[f"rejected_{length_key}"].shape[1])

    concatenated_batch = {}

    # All chosen entries are handled first, then all rejected entries, so the
    # rejected half always lands below the chosen half.
    for prefix in ("chosen", "rejected"):
        for key in batch:
            if not (key.startswith(prefix) and isinstance(batch[key], torch.Tensor)):
                continue

            if "labels" in key or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif key.endswith("_input_ids"):
                pad_value = padding_value
            elif key.endswith("_attention_mask"):
                pad_value = 0

            target_key = key.replace(prefix, "concatenated")
            if prefix == "chosen":
                concatenated_batch[target_key] = pad_to_length(batch[key], max_length, pad_value=pad_value)
            else:
                concatenated_batch[target_key] = torch.cat(
                    (
                        concatenated_batch[target_key],
                        pad_to_length(batch[key], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

    if is_encoder_decoder:
        # The encoder sees the same prompt for both halves of the batch.
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch
970
+
971
def cpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the per-example CPO loss from the policy log probabilities.

    Args:
        policy_chosen_logps:
            Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps:
            Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        ``(losses, chosen_rewards, rejected_rewards)``. ``losses`` holds one
        value per example; the rewards are the detached log probabilities
        scaled by ``self.beta``.

    Raises:
        ValueError: If ``self.loss_type`` is not one of the supported types.
    """
    device = self.accelerator.device
    logits = (policy_chosen_logps - policy_rejected_logps).to(device)

    # beta acts as a temperature; label_smoothing encodes uncertainty about
    # the preference labels and yields a more conservative loss.
    loss_type = self.loss_type
    if loss_type in ("simpo", "sigmoid"):
        if loss_type == "simpo":
            # SimPO subtracts a target reward margin before the sigmoid loss.
            logits = logits - self.simpo_gamma / self.beta
        # Reduces to Equation 3 of the CPO paper when label_smoothing -> 0.
        scaled = self.beta * logits
        losses = (
            -F.logsigmoid(scaled) * (1 - self.label_smoothing)
            - F.logsigmoid(-scaled) * self.label_smoothing
        )
    elif loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    elif loss_type == "ipo":
        # IPO, eqn (17): beta plays the role of the regularization parameter tau.
        losses = (logits - 1 / (2 * self.beta)) ** 2
    else:
        raise ValueError(
            f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
        )

    chosen_rewards = self.beta * (policy_chosen_logps.to(device)).detach()
    rejected_rewards = self.beta * (policy_rejected_logps.to(device)).detach()

    return losses, chosen_rewards, rejected_rewards
1023
+
1024
+ @staticmethod
1025
+ def get_batch_logps(
1026
+ logits: torch.FloatTensor,
1027
+ labels: torch.LongTensor,
1028
+ average_log_prob: bool = False,
1029
+ label_pad_token_id: int = -100,
1030
+ is_encoder_decoder: bool = False,
1031
+ ) -> torch.FloatTensor:
1032
+ """Compute the log probabilities of the given labels under the given logits.
1033
+
1034
+ Args:
1035
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1036
+ labels:
1037
+ Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
1038
+ ignored. Shape: (batch_size, sequence_length)
1039
+ average_log_prob:
1040
+ If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
1041
+ log probabilities of the (non-masked) tokens.
1042
+ label_pad_token_id: The label pad token id.
1043
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
1044
+
1045
+ Returns:
1046
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
1047
+ given logits.
1048
+ """
1049
+ if logits.shape[:-1] != labels.shape:
1050
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1051
+
1052
+ if not is_encoder_decoder:
1053
+ labels = labels[:, 1:].clone()
1054
+ logits = logits[:, :-1, :]
1055
+ loss_mask = labels != label_pad_token_id
1056
+
1057
+ # dummy token; we'll ignore the losses on these tokens later
1058
+ labels[labels == label_pad_token_id] = 0
1059
+
1060
+ per_token_logps = selective_log_softmax(logits, labels)
1061
+
1062
+ if average_log_prob:
1063
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1064
+ else:
1065
+ return (per_token_logps * loss_mask).sum(-1)
1066
+
1067
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run one forward pass over the chosen and rejected inputs stacked together.

    A single pass over the concatenated batch avoids two separate forward
    passes, which is faster under FSDP.

    Returns:
        ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)``,
        with ``outputs.aux_loss`` appended as a sixth element when
        ``self.aux_loss_enabled`` is set.
    """
    device = self.accelerator.device
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=device,
    )
    # The first ``len_chosen`` rows of every concatenated tensor are the chosen half.
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = {}
    if self.is_encoder_decoder:
        model_kwargs["decoder_input_ids"] = self._shift_right(concatenated_batch["concatenated_labels"])
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    labels = concatenated_batch["concatenated_labels"].clone()

    if self.cpo_alpha == 0:
        # The NLL term is switched off; skip the cross entropy entirely.
        nll_loss = torch.tensor(0.0).to(device)
    else:
        nll_logits = all_logits[:len_chosen]
        nll_labels = labels[:len_chosen]
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            nll_logits = nll_logits[..., :-1, :].contiguous()
            nll_labels = nll_labels[..., 1:].contiguous()
        # Flatten the tokens; labels follow the logits' device for model parallelism.
        flat_logits = nll_logits.view(-1, nll_logits.shape[-1])
        flat_labels = nll_labels.view(-1).to(flat_logits.device)
        nll_loss = nn.CrossEntropyLoss()(flat_logits, flat_labels)

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=self.loss_type in ["ipo", "simpo"],
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    result = (
        all_logps[:len_chosen],
        all_logps[len_chosen:],
        all_logits[:len_chosen],
        all_logits[len_chosen:],
        nll_loss,
    )
    if self.aux_loss_enabled:
        return result + (outputs.aux_loss,)
    return result
1141
+
1142
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the CPO loss and the logging metrics for one batch.

    Args:
        model: Model passed through to ``concatenated_forward``.
        batch: Batch of chosen/rejected inputs.
        train_eval: ``"eval"`` prefixes every metric name with ``eval_``.

    Returns:
        ``(loss, metrics)`` where ``metrics`` maps metric names to Python floats.
    """
    forward_output = self.concatenated_forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_nll_loss,
    ) = forward_output[:5]
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    losses, chosen_rewards, rejected_rewards = self.cpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
    )

    loss = losses.mean() + self.cpo_alpha * policy_nll_loss
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    def gathered_mean(tensor):
        # Gather across processes, then reduce to one Python float.
        return self.accelerator.gather_for_metrics(tensor).detach().mean().item()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics = {}
    metrics[f"{prefix}rewards/chosen"] = gathered_mean(chosen_rewards)
    metrics[f"{prefix}rewards/rejected"] = gathered_mean(rejected_rewards)
    metrics[f"{prefix}rewards/accuracies"] = gathered_mean(reward_accuracies)
    metrics[f"{prefix}rewards/margins"] = gathered_mean(chosen_rewards - rejected_rewards)
    metrics[f"{prefix}logps/rejected"] = gathered_mean(policy_rejected_logps)
    metrics[f"{prefix}logps/chosen"] = gathered_mean(policy_chosen_logps)
    # Logits are reduced to a scalar locally before the gather.
    metrics[f"{prefix}logits/rejected"] = gathered_mean(policy_rejected_logits.detach().mean())
    metrics[f"{prefix}logits/chosen"] = gathered_mean(policy_chosen_logits.detach().mean())
    metrics[f"{prefix}nll_loss"] = gathered_mean(policy_nll_loss)

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
1195
+
1196
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for one batch and record its metrics.

    Args:
        model: Model to evaluate the batch with.
        inputs: Batch of chosen/rejected inputs.
        return_outputs: When true, return ``(loss, metrics)`` instead of the loss alone.
        num_items_in_batch: Accepted for signature compatibility; not used here.

    Returns:
        The loss, or ``(loss, metrics)`` when ``return_outputs`` is true.
    """
    # Autocast is only entered when the PEFT model has been cast to bf16.
    if self._peft_has_been_casted_to_bf16:
        compute_loss_context_manager = autocast(self.accelerator.device.type)
    else:
        compute_loss_context_manager = nullcontext()

    with compute_loss_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
1216
+
1217
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
    """Sample completions from ``model`` for the prompts in ``batch``.

    Args:
        model: Model whose ``generate`` method is called.
        batch: Batch providing 'prompt_input_ids' and 'prompt_attention_mask'.

    Returns:
        The decoded generations, one string per prompt, with special tokens
        removed.
    """
    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch amp context manager as some hidden states are silently casted to full precision.
    if self._peft_has_been_casted_to_bf16:
        generate_context_manager = autocast(self.accelerator.device.type)
    else:
        generate_context_manager = nullcontext()

    pad_token_id = self.processing_class.pad_token_id

    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Pad every generation to a common length before decoding.
    policy_output = pad_to_length(policy_output, self.max_length, pad_token_id)
    return self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1239
+
1240
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one evaluation step and record its metrics.

    Args:
        model: Model to evaluate.
        inputs: Batch of chosen/rejected inputs.
        prediction_loss_only: When true, only the loss is returned.
        ignore_keys: Metric names to drop from the returned logits. Defaults
            to ``model.config.keys_to_ignore_at_inference`` when available.

    Returns:
        ``(loss, logits, labels)``; ``logits`` and ``labels`` are ``None``
        when ``prediction_loss_only`` is true. ``logits`` holds the mean
        chosen / rejected logit metrics and ``labels`` is a zero tensor of
        the same length.
    """
    if ignore_keys is None:
        ignore_keys = (
            getattr(model.config, "keys_to_ignore_at_inference", []) if hasattr(model, "config") else []
        )

    if self._peft_has_been_casted_to_bf16:
        prediction_context_manager = autocast(self.accelerator.device.type)
    else:
        prediction_context_manager = nullcontext()

    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    detached_loss = loss.detach()
    if prediction_loss_only:
        return (detached_loss, None, None)

    # logits for the chosen and rejected samples from model
    candidates = (
        ("eval_logits/chosen", metrics["eval_logits/chosen"]),
        ("eval_logits/rejected", metrics["eval_logits/rejected"]),
    )
    kept = [value for name, value in candidates if name not in ignore_keys]
    device = self.accelerator.device
    logits = torch.tensor(kept, device=device)
    labels = torch.zeros(logits.shape[0], device=device)

    return (detached_loss, logits, labels)
1276
+
1277
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to the per-split history in ``self._stored_metrics``."""
    for name in metrics:
        self._stored_metrics[train_eval][name].append(metrics[name])
1280
+
1281
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
    `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, one random batch is sampled from the evaluation dataset, completions
    are generated for it and logged as a "game_log" table to wandb / comet_ml (whichever is listed in
    `self.args.report_to`) before the regular evaluation runs.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples.
        # The sample size is clamped to the dataset size: `random.sample` raises
        # ValueError when asked for more items than the population holds, which
        # would abort evaluation on datasets smaller than one eval batch.
        num_samples = len(dataloader.dataset)
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        # Strip the prompt text from the front of each decoded generation.
        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
1330
+
1331
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    The metrics accumulated by `store_metrics` for the matching split are averaged, merged into `logs` and then
    cleared.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # logs either has 'loss' or 'eval_loss'
    if "loss" in logs:
        train_eval = "train"
    else:
        train_eval = "eval"

    # Add averaged stored metrics to logs
    pending = self._stored_metrics[train_eval]
    for name, values in pending.items():
        logs[name] = torch.tensor(values).mean().item()

    # Start the next logging window with an empty history.
    del self._stored_metrics[train_eval]
    return super().log(logs, start_time)
1348
+
1349
+ def _shift_right(self, input_ids):
1350
+ if self.decoder_start_token_id is None:
1351
+ raise ValueError(
1352
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1353
+ )
1354
+
1355
+ # shift inputs to the right
1356
+ if is_torch_fx_proxy(input_ids):
1357
+ # Item assignment is not supported natively for proxies.
1358
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1359
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1360
+ else:
1361
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1362
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1363
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1364
+
1365
+ if self.pad_token_id is None:
1366
+ raise ValueError("model.config.pad_token_id has to be defined.")
1367
+ # replace possible -100 values in labels by `pad_token_id`
1368
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1369
+
1370
+ return shifted_input_ids
1371
+
1372
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
    """Write the model card, then delegate to the parent checkpoint logic.

    The model name is the last path component of ``args.output_dir`` unless
    ``args.hub_model_id`` is set, in which case the part after the final
    ``/`` of the hub id is used.
    """
    hub_model_id = self.args.hub_model_id
    model_name = Path(self.args.output_dir).name if hub_model_id is None else hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
1380
+
1381
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing happens on processes other than the
    world-zero process.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when the recorded name is not a local directory.
    config = self.model.config
    base_model = None
    if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
        base_model = config._name_or_path

    # normalize `tags` to a mutable set
    if tags is None:
        tag_set = set()
    elif isinstance(tags, str):
        tag_set = {tags}
    else:
        tag_set = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tag_set.add("unsloth")

    tag_set.update(self._tag_names)

    # NOTE(review): intra-line spacing of this BibTeX entry is taken from the
    # visible text; confirm against the upstream file if exact alignment matters.
    citation = textwrap.dedent("""\
        @inproceedings{xu2024contrastive,
            title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
            author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
            year = 2024,
            booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
            publisher = {OpenReview.net},
            url = {https://openreview.net/forum?id=51iwkioZpn}
        }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tag_set,
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="CPO",
        trainer_citation=citation,
        paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
        paper_id="2401.08417",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    model_card.save(readme_path)
1443
class UnslothCPOTrainer(_UnslothCPOTrainer):
    """
    CPO trainer wrapper that normalizes precision, evaluation and collator settings before delegating to the
    underlying CPO trainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train.
        args (`CPOConfig`):
            The CPO config arguments to use for training. A default `UnslothCPOConfig` is created when omitted.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class:
            Tokenizer / processor used to process the data. Its padding side is forced to `'right'`.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics.
        **kwargs:
            Forwarded unchanged to the parent trainer.
    """

    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None:
            args = UnslothCPOConfig()

        # ---- Training precision -------------------------------------------
        # Only genuine booleans count as an explicit bf16 / fp16 request.
        use_bf16 = getattr(args, 'bf16', False)
        if not isinstance(use_bf16, bool):
            use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if not isinstance(use_fp16, bool):
            use_fp16 = False

        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        else:
            # A requested precision must match the precision the model was loaded in.
            if float16 and use_bf16:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not float16 and use_fp16:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
            if not (use_bf16 or use_fp16) and mixed_precision_dtype == 'float32':
                # Nothing requested: follow the model's own dtype.
                args.fp16 = float16
                args.bf16 = not float16
                os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ------------------------------------------
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # ---- Full-eval precision ------------------------------------------
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if not isinstance(fp16_full_eval, bool):
            fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if not isinstance(bf16_full_eval, bool):
            bf16_full_eval = False

        # A full-eval flag that contradicts the training precision is flipped.
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False

        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metric callbacks need the logits returned by the forward pass.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # ---- Sequence length ----------------------------------------------
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model.max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Tokenizer / collator -----------------------------------------
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'
        tokenizer_like = processing_class

        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if isinstance(data_collator, UnslothVisionDataCollator):
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}
        else:
            # Swap the collator when it disagrees with whether the dataset has labels.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(tokenizer_like, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(tokenizer_like)

        if not isinstance(data_collator, UnslothVisionDataCollator):
            # A processor without `pad` wraps its tokenizer; rebuild the collator on the inner one.
            if not hasattr(tokenizer_like, 'pad') and hasattr(tokenizer_like, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(tokenizer_like.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(tokenizer_like.tokenizer, mlm = False, mlm_probability = 0.0)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('cpo_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            **kwargs
        )

        # ---- Post-init patches --------------------------------------------
        # Drop the NEFTune forward hook installed by the parent, if any, and
        # store the noise alpha on the embedding module instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        if hasattr(self, 'accelerator'):
            # Attach the accelerator's grad scaler to every level of the `.model` wrapper chain.
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
1617
+
1618
+ pass
unsloth_compiled_cache/UnslothDDPOTrainer.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, Path, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warnings)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Option mapping handed to `torch.compile(..., options=...)` by the compiled
# helper defined right after it. Keys are passed through verbatim.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """Return the log-softmax value of `logits` at the positions given by `index`.

    The logits are flattened to 2-D and processed in (at most) four chunks along
    the row dimension; each chunk is upcast to float32 before the gather and
    logsumexp. The result is reshaped to `(logits.shape[0], logits.shape[1])`,
    so `logits` is expected to carry two leading dimensions plus a final
    class dimension.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)

    pieces = []
    # Per chunk: log_softmax(x)[i] == x[i] - logsumexp(x)
    for logit_piece, index_piece in zip(
        torch.chunk(flat_logits, chunks = 4, dim = 0),
        torch.chunk(flat_index, chunks = 4, dim = 0),
    ):
        logit_piece = logit_piece.to(torch.float32)
        picked = torch.gather(logit_piece, dim = -1, index = index_piece.unsqueeze(-1)).squeeze(-1)
        pieces.append(picked - torch.logsumexp(logit_piece, dim = -1))

    joined = torch.concat(pieces)
    return joined.reshape((logits.shape[0], logits.shape[1]))
51
@dataclass
class UnslothDDPOConfig(DDPOConfig):
    """
    Configuration class for the [`DDPOTrainer`], extended with two Unsloth fields.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The defaults listed below are the ones declared by this class's `__init__`. Any extra keyword argument
    (for example `tracker_kwargs`, `accelerator_kwargs` or `project_kwargs`) is forwarded unchanged to
    `DDPOConfig.__init__` through `**kwargs`.

    Parameters:
        exp_name (`str`, *optional*, defaults to `"colab_kernel_launcher"`):
            Name of this experiment.
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `3407`):
            Random seed.
        log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard', check
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        sample_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for sampling.
        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
            Number of batches to sample per epoch.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Use 8bit Adam optimizer from bitsandbytes.
        train_learning_rate (`float`, *optional*, defaults to `5e-05`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Adam beta1.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Adam beta2.
        train_adam_weight_decay (`float`, *optional*, defaults to `0.01`):
            Adam weight decay.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
            Adam epsilon.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `2`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
            Number of inner epochs per outer epoch.
        train_cfg (`bool`, *optional*, defaults to `True`):
            Whether to use classifier-free guidance during training.
        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Clip advantages to the range.
        train_clip_range (`float`, *optional*, defaults to `1e-4`):
            PPO clip range.
        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
            Fraction of timesteps to train on.
        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
            Whether to track statistics for each prompt separately.
        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
            Number of reward values to store in the buffer for each prompt.
        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values to store in the buffer.
        async_reward_computation (`bool`, *optional*, defaults to `False`):
            Whether to compute rewards asynchronously.
        max_workers (`int`, *optional*, defaults to `2`):
            Maximum number of workers to use for async reward computation.
        negative_prompts (`str`, *optional*, defaults to `""`):
            Comma-separated list of prompts to use as negative examples.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model checkpoint to the Hub.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            Stored on the instance as-is; not forwarded to `DDPOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Stored on the instance as-is; not forwarded to `DDPOConfig`.
    """

    # Unsloth-specific fields; both are assigned in __init__ after the parent is initialised.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        exp_name = 'colab_kernel_launcher',
        run_name = '',
        seed = 3407,
        log_with = None,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        sample_batch_size = 1,
        sample_num_batches_per_epoch = 2,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        train_num_inner_epochs = 1,
        train_cfg = True,
        train_adv_clip_max = 5.0,
        train_clip_range = 0.0001,
        train_timestep_fraction = 1.0,
        per_prompt_stat_tracking = False,
        per_prompt_stat_tracking_buffer_size = 16,
        per_prompt_stat_tracking_min_count = 16,
        async_reward_computation = False,
        max_workers = 2,
        negative_prompts = '',
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        """Forward every DDPO option to `DDPOConfig` and keep the two Unsloth extras locally."""
        super().__init__(
            # run identity / logging
            exp_name = exp_name, run_name = run_name, seed = seed, log_with = log_with,
            tracker_project_name = tracker_project_name, logdir = logdir,
            # schedule / checkpoints / precision
            num_epochs = num_epochs, save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision, allow_tf32 = allow_tf32,
            resume_from = resume_from,
            # sampling
            sample_num_steps = sample_num_steps, sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            sample_batch_size = sample_batch_size,
            sample_num_batches_per_epoch = sample_num_batches_per_epoch,
            # optimisation
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1, train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            train_num_inner_epochs = train_num_inner_epochs,
            train_cfg = train_cfg,
            train_adv_clip_max = train_adv_clip_max,
            train_clip_range = train_clip_range,
            train_timestep_fraction = train_timestep_fraction,
            # reward statistics / async rewards
            per_prompt_stat_tracking = per_prompt_stat_tracking,
            per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
            per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
            async_reward_computation = async_reward_computation,
            max_workers = max_workers,
            # misc
            negative_prompts = negative_prompts,
            push_to_hub = push_to_hub,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
245
+
246
+ class _UnslothDDPOTrainer(PyTorchModelHubMixin):
247
+ """"""
248
+
249
+ _tag_names = ["trl", "ddpo"]
250
+
251
def __init__(
    self,
    config: DDPOConfig,
    reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
    prompt_function: Callable[[], tuple[str, Any]],
    sd_pipeline: DDPOStableDiffusionPipeline,
    image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
):
    """Set up the accelerator, pipeline, optimizer and resume state for DDPO training.

    Args:
        config: Trainer configuration; stored as `self.config`.
        reward_function: Called as `reward_function(images, prompts, prompt_metadata)`.
        prompt_function: Called with no arguments to obtain a `(prompt, metadata)` pair.
        sd_pipeline: The Stable Diffusion pipeline wrapper to train.
        image_samples_hook: Optional callback invoked from `step` with the sampled images.

    Raises:
        ValueError: If `config.resume_from` is a directory without any `checkpoint_*`
            entry, or if `_config_check` rejects the batch-size settings.
    """
    warnings.warn(
        "DDPOTrainer is deprecated and will be removed in version 0.23.0.",
        DeprecationWarning,
    )
    if image_samples_hook is None:
        warnings.warn("No image_samples_hook provided; no images will be logged")

    self.prompt_fn = prompt_function
    self.reward_fn = reward_function
    self.config = config
    self.image_samples_callback = image_samples_hook

    accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

    if self.config.resume_from:
        resume_path = os.path.normpath(os.path.expanduser(self.config.resume_from))
        self.config.resume_from = resume_path
        if "checkpoint_" not in os.path.basename(resume_path):
            # A parent directory was given: pick the highest-numbered checkpoint inside it.
            checkpoints = [entry for entry in os.listdir(resume_path) if "checkpoint_" in entry]
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {self.config.resume_from}")
            latest_checkpoint = max(int(entry.split("_")[-1]) for entry in checkpoints)
            self.config.resume_from = os.path.join(resume_path, f"checkpoint_{latest_checkpoint}")

            accelerator_project_config.iteration = latest_checkpoint + 1

    # number of timesteps within each trajectory to train on
    self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)

    # We always accumulate gradients across timesteps; train_gradient_accumulation_steps is meant to be
    # the number of *samples* to accumulate across, so it is multiplied by the number of training
    # timesteps to get the total number of accumulation steps.
    self.accelerator = Accelerator(
        log_with=self.config.log_with,
        mixed_precision=self.config.mixed_precision,
        project_config=accelerator_project_config,
        gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
        **self.config.accelerator_kwargs,
    )

    is_okay, message = self._config_check()
    if not is_okay:
        raise ValueError(message)

    is_using_tensorboard = config.log_with == "tensorboard"

    if self.accelerator.is_main_process:
        # Tensorboard receives the flat config; other trackers get it nested under one key.
        if is_using_tensorboard:
            tracker_config = config.to_dict()
        else:
            tracker_config = dict(ddpo_trainer_config=config.to_dict())
        self.accelerator.init_trackers(
            self.config.tracker_project_name,
            config=tracker_config,
            init_kwargs=self.config.tracker_kwargs,
        )

    logger.info(f"\n{config}")

    set_seed(self.config.seed, device_specific=True)

    self.sd_pipeline = sd_pipeline

    self.sd_pipeline.set_progress_bar_config(
        position=1,
        disable=not self.accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    # For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and
    # non-lora unet] to half-precision as these weights are only used for inference, keeping weights in
    # full precision is not required.
    dtype_by_precision = {"fp16": torch.float16, "bf16": torch.bfloat16}
    inference_dtype = dtype_by_precision.get(self.accelerator.mixed_precision, torch.float32)

    device = self.accelerator.device
    self.sd_pipeline.vae.to(device, dtype=inference_dtype)
    self.sd_pipeline.text_encoder.to(device, dtype=inference_dtype)
    self.sd_pipeline.unet.to(device, dtype=inference_dtype)

    trainable_layers = self.sd_pipeline.get_trainable_layers()

    self.accelerator.register_save_state_pre_hook(self._save_model_hook)
    self.accelerator.register_load_state_pre_hook(self._load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if self.config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if isinstance(trainable_layers, list):
        optimizer_params = trainable_layers
    else:
        optimizer_params = trainable_layers.parameters()
    self.optimizer = self._setup_optimizer(optimizer_params)

    tokenizer = self.sd_pipeline.tokenizer
    negative_prompts = [""] if self.config.negative_prompts is None else self.config.negative_prompts
    negative_ids = tokenizer(
        negative_prompts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(self.accelerator.device)
    self.neg_prompt_embed = self.sd_pipeline.text_encoder(negative_ids)[0]

    if config.per_prompt_stat_tracking:
        self.stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking_buffer_size,
            config.per_prompt_stat_tracking_min_count,
        )

    # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't
    # necessary and it uses more memory
    self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

    if getattr(self.sd_pipeline, "use_lora", False):
        unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
        self.trainable_layers = [param for param in unet.parameters() if param.requires_grad]
    else:
        self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

    if self.config.async_reward_computation:
        self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        self.accelerator.load_state(config.resume_from)
        # The checkpoint directory name ends in "_<epoch>"; continue with the following epoch.
        self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        self.first_epoch = 0
396
+
397
def compute_rewards(self, prompt_image_pairs, is_async=False):
    """Evaluate `self.reward_fn` on every `[images, prompts, prompt_metadata]` entry.

    Args:
        prompt_image_pairs: Iterable of three-element entries that are unpacked
            into the positional arguments of `self.reward_fn`.
        is_async: When true, the reward function is dispatched through
            `self.executor.map`; otherwise it is called inline.

    Returns:
        An iterator (from `zip`) that yields two tuples: the rewards, each
        converted with `torch.as_tensor` onto `self.accelerator.device`, and
        the matching reward metadata.
    """
    device = self.accelerator.device

    def _resolve(value):
        # `Executor.map` already yields finished results, so an ordinary reward
        # function hands back plain values here. Only unwrap objects that
        # actually expose `.result()` (e.g. futures returned by the reward
        # function itself); calling it unconditionally raised AttributeError.
        getter = getattr(value, "result", None)
        return getter() if callable(getter) else value

    if not is_async:
        rewards = []
        for images, prompts, prompt_metadata in prompt_image_pairs:
            reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
            rewards.append((torch.as_tensor(reward, device=device), reward_metadata))
    else:
        outcomes = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
        rewards = [
            (torch.as_tensor(_resolve(reward), device=device), _resolve(reward_metadata))
            for reward, reward_metadata in outcomes
        ]

    return zip(*rewards)
416
+
417
def step(self, epoch: int, global_step: int):
    """
    Run one outer epoch: sample, score, compute advantages and train.

    Args:
        epoch (int): The current epoch.
        global_step (int): The current global step.

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.
        - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
          and the accelerator tracker.

    Returns:
        global_step (int): The updated global step.

    Raises:
        ValueError: If an inner epoch ends without the accelerator having synchronised gradients.
    """
    cfg = self.config
    device = self.accelerator.device

    samples, prompt_image_data = self._generate_samples(
        iterations=cfg.sample_num_batches_per_epoch,
        batch_size=cfg.sample_batch_size,
    )

    # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
    samples = {key: torch.cat([batch[key] for batch in samples]) for key in samples[0].keys()}
    rewards, rewards_metadata = self.compute_rewards(prompt_image_data, is_async=cfg.async_reward_computation)

    # Attach each batch's reward and reward metadata to its [images, prompts, metadata] entry.
    for image_data, reward, reward_metadata in zip(prompt_image_data, rewards, rewards_metadata):
        image_data.extend([reward, reward_metadata])

    if self.image_samples_callback is not None:
        self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

    rewards = torch.cat(rewards)
    rewards = self.accelerator.gather(rewards).cpu().numpy()

    reward_stats = {
        "reward": rewards,
        "epoch": epoch,
        "reward_mean": rewards.mean(),
        "reward_std": rewards.std(),
    }
    self.accelerator.log(reward_stats, step=global_step)

    if cfg.per_prompt_stat_tracking:
        # gather the prompts across processes
        prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
        prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
        advantages = self.stat_tracker.update(prompts, rewards)
    else:
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # ungather advantages; keep the entries corresponding to the samples on this process
    per_process_advantages = torch.as_tensor(advantages).reshape(self.accelerator.num_processes, -1)
    samples["advantages"] = per_process_advantages[self.accelerator.process_index].to(device)

    del samples["prompt_ids"]

    total_batch_size, num_timesteps = samples["timesteps"].shape

    for inner_epoch in range(cfg.train_num_inner_epochs):
        # shuffle samples along batch dimension
        batch_order = torch.randperm(total_batch_size, device=device)
        samples = {key: value[batch_order] for key, value in samples.items()}

        # shuffle along time dimension independently for each sample:
        # `time_order` holds one permutation of the timesteps per sample, and indexing with a
        # column of row numbers plus `time_order` selects element [b, time_order[b, t]] for every (b, t).
        time_order = torch.stack(
            [torch.randperm(num_timesteps, device=device) for _ in range(total_batch_size)]
        )
        row_numbers = torch.arange(total_batch_size, device=device)[:, None]
        for key in ("timesteps", "latents", "next_latents", "log_probs"):
            samples[key] = samples[key][row_numbers, time_order]

        # rebatch them as user defined train_batch_size is different from sample_batch_size
        keys = samples.keys()
        rebatched = [
            value.reshape(-1, cfg.train_batch_size, *value.shape[1:]) for value in samples.values()
        ]
        # Transpose so that each element is one training batch: a dict with every key.
        samples_batched = [dict(zip(keys, row_values)) for row_values in zip(*rebatched)]

        self.sd_pipeline.unet.train()
        global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
        # ensure optimization step at the end of the inner epoch
        if not self.accelerator.sync_gradients:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )

    should_save = epoch != 0 and epoch % cfg.save_freq == 0 and self.accelerator.is_main_process
    if should_save:
        self.accelerator.save_state()

    return global_step
523
+
524
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
    """
    Compute the clipped policy loss and diagnostics for one timestep of a batch.

    Args:
        latents (torch.Tensor):
            The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
        timesteps (torch.Tensor):
            The timesteps sampled from the diffusion model, shape: [batch_size]
        next_latents (torch.Tensor):
            The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height,
            width]
        log_probs (torch.Tensor):
            The log probabilities of the latents recorded at sampling time, shape: [batch_size]
        advantages (torch.Tensor):
            The advantages of the latents, shape: [batch_size]
        embeds (torch.Tensor):
            The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] Note: the "or" is because if
            train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

    Returns:
        loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor); each is the result of a
        `torch.mean` reduction, i.e. a scalar tensor.
    """
    cfg = self.config
    unet = self.sd_pipeline.unet

    with self.autocast():
        if cfg.train_cfg:
            # One forward pass over the doubled batch, then split into the two halves.
            doubled_prediction = unet(
                torch.cat([latents] * 2),
                torch.cat([timesteps] * 2),
                embeds,
            ).sample
            noise_pred_uncond, noise_pred_text = doubled_prediction.chunk(2)
            guidance_delta = noise_pred_text - noise_pred_uncond
            noise_pred = noise_pred_uncond + cfg.sample_guidance_scale * guidance_delta
        else:
            noise_pred = unet(latents, timesteps, embeds).sample

        # compute the log prob of next_latents given latents under the current model
        scheduler_step_output = self.sd_pipeline.scheduler_step(
            noise_pred,
            timesteps,
            latents,
            eta=cfg.sample_eta,
            prev_sample=next_latents,
        )
        log_prob = scheduler_step_output.log_probs

    advantages = torch.clamp(advantages, -cfg.train_adv_clip_max, cfg.train_adv_clip_max)

    log_ratio = log_prob - log_probs
    ratio = torch.exp(log_ratio)

    loss = self.loss(advantages, cfg.train_clip_range, ratio)

    approx_kl = 0.5 * torch.mean(log_ratio ** 2)

    outside_clip_range = torch.abs(ratio - 1.0) > cfg.train_clip_range
    clipfrac = torch.mean(outside_clip_range.float())

    return loss, approx_kl, clipfrac
591
+
592
def loss(
    self,
    advantages: torch.Tensor,
    clip_range: float,
    ratio: torch.Tensor,
):
    """Return the mean clipped surrogate loss.

    For each element the loss is the larger of `-advantages * ratio` and
    `-advantages * clamp(ratio, 1 - clip_range, 1 + clip_range)`; the mean
    over all elements is returned.
    """
    bounded_ratio = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    elementwise = torch.maximum(-advantages * ratio, -advantages * bounded_ratio)
    return torch.mean(elementwise)
605
+
606
+ def _setup_optimizer(self, trainable_layers_parameters):
607
+ if self.config.train_use_8bit_adam:
608
+ import bitsandbytes
609
+
610
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
611
+ else:
612
+ optimizer_cls = torch.optim.AdamW
613
+
614
+ return optimizer_cls(
615
+ trainable_layers_parameters,
616
+ lr=self.config.train_learning_rate,
617
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
618
+ weight_decay=self.config.train_adam_weight_decay,
619
+ eps=self.config.train_adam_epsilon,
620
+ )
621
+
622
+ def _save_model_hook(self, models, weights, output_dir):
623
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
624
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
625
+
626
+ def _load_model_hook(self, models, input_dir):
627
+ self.sd_pipeline.load_checkpoint(models, input_dir)
628
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
629
+
630
def _generate_samples(self, iterations, batch_size):
    """
    Sample trajectories from the pipeline with the UNet in eval mode.

    Args:
        iterations (int): Number of iterations to generate samples for
        batch_size (int): Batch size to use for sampling

    Returns:
        samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
    """
    pipeline = self.sd_pipeline
    tokenizer = pipeline.tokenizer
    cfg = self.config

    samples = []
    prompt_image_pairs = []
    pipeline.unet.eval()

    sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

    for _ in range(iterations):
        # Draw `batch_size` (prompt, metadata) pairs and split them into two tuples.
        prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])

        tokenized = tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
        )
        prompt_ids = tokenized.input_ids.to(self.accelerator.device)
        prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

        with self.autocast():
            sd_output = pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=cfg.sample_num_steps,
                guidance_scale=cfg.sample_guidance_scale,
                eta=cfg.sample_eta,
                output_type="pt",
            )

            images = sd_output.images
            latents = sd_output.latents
            log_probs = sd_output.log_probs

        latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, ...)
        log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
        timesteps = pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)

        trajectory = {
            "prompt_ids": prompt_ids,
            "prompt_embeds": prompt_embeds,
            "timesteps": timesteps,
            "latents": latents[:, :-1],  # each entry is the latent before timestep t
            "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
            "log_probs": log_probs,
            "negative_prompt_embeds": sample_neg_prompt_embeds,
        }
        samples.append(trajectory)
        prompt_image_pairs.append([images, prompts, prompt_metadata])

    return samples, prompt_image_pairs
691
+
692
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
    """
    Run the optimisation loop over the given training batches.

    Args:
        inner_epoch (int): The current inner epoch
        epoch (int): The current epoch
        global_step (int): The current global step
        batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.

    Returns:
        global_step (int): The updated global step
    """
    info = defaultdict(list)
    for sample in batched_samples:
        if self.config.train_cfg:
            # concat negative prompts to sample prompts to avoid two forward passes
            embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
        else:
            embeds = sample["prompt_embeds"]

        for j in range(self.num_train_timesteps):
            with self.accelerator.accumulate(self.sd_pipeline.unet):
                loss, approx_kl, clipfrac = self.calculate_loss(
                    sample["latents"][:, j],
                    sample["timesteps"][:, j],
                    sample["next_latents"][:, j],
                    sample["log_probs"][:, j],
                    sample["advantages"],
                    embeds,
                )
                for stat_name, stat_value in (("approx_kl", approx_kl), ("clipfrac", clipfrac), ("loss", loss)):
                    info[stat_name].append(stat_value)

                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients:
                    # `trainable_layers` is either a plain list of parameters or an object exposing them.
                    if isinstance(self.trainable_layers, list):
                        params_to_clip = self.trainable_layers
                    else:
                        params_to_clip = self.trainable_layers.parameters()
                    self.accelerator.clip_grad_norm_(params_to_clip, self.config.train_max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if self.accelerator.sync_gradients:
                # log training-related stuff: average everything collected since the last log
                info = {name: torch.mean(torch.stack(values)) for name, values in info.items()}
                info = self.accelerator.reduce(info, reduction="mean")
                info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                self.accelerator.log(info, step=global_step)
                global_step += 1
                info = defaultdict(list)
    return global_step
752
+
753
+ def _config_check(self) -> tuple[bool, str]:
754
+ samples_per_epoch = (
755
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
756
+ )
757
+ total_train_batch_size = (
758
+ self.config.train_batch_size
759
+ * self.accelerator.num_processes
760
+ * self.config.train_gradient_accumulation_steps
761
+ )
762
+
763
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
764
+ return (
765
+ False,
766
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
767
+ )
768
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
769
+ return (
770
+ False,
771
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
772
+ )
773
+ if not samples_per_epoch % total_train_batch_size == 0:
774
+ return (
775
+ False,
776
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
777
+ )
778
+ return True, ""
779
+
780
def train(self, epochs: Optional[int] = None):
    """
    Train the model for a given number of epochs.

    Args:
        epochs (`int` or `None`, *optional*, defaults to `None`):
            Exclusive upper bound of the epoch range. When `None`, `self.config.num_epochs` is used. Iteration
            starts at `self.first_epoch`.
    """
    global_step = 0
    if epochs is None:
        epochs = self.config.num_epochs
    for epoch in range(self.first_epoch, epochs):
        # `step` returns the updated global step, which is threaded into the next epoch
        global_step = self.step(epoch, global_step)
789
+
790
def _save_pretrained(self, save_directory):
    """Save the Stable Diffusion pipeline to `save_directory`, then generate the model card."""
    self.sd_pipeline.save_pretrained(save_directory)
    self.create_model_card()
793
+
794
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
    """
    Write the model card, then delegate to the parent class to save the checkpoint.

    The model name is the last path component of `self.args.hub_model_id` when it is set, otherwise the name of
    the `self.args.output_dir` directory.
    """
    if self.args.hub_model_id is None:
        model_name = Path(self.args.output_dir).name
    else:
        model_name = self.args.hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
802
+
803
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes for which
    `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that is a local directory is not reported as the base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # normalize `tags` to a mutable set
    if tags is None:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    # NOTE(review): whitespace inside this BibTeX entry was reconstructed; confirm against the upstream text.
    citation = textwrap.dedent("""\
    @inproceedings{black2024training,
        title = {{Training Diffusion Models with Reinforcement Learning}},
        author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
        year = 2024,
        booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
        publisher = {OpenReview.net},
        url = {https://openreview.net/forum?id=YCWjhGrJFD},
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="DDPO",
        trainer_citation=citation,
        paper_title="Training Diffusion Models with Reinforcement Learning",
        paper_id="2305.13301",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
866
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
    """

    The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is
    heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch As of now only Stable Diffusion based
    pipelines are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for
        more details.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        # The configuration parameter of this trainer is `config`; the previous
        # check tested an undefined local `args` and raised UnboundLocalError
        # on every construction.
        # NOTE(review): `UnslothDDPOConfig` is assumed to be defined earlier in
        # this module -- it is not visible in this part of the file.
        if config is None: config = UnslothDDPOConfig()
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ddpo_trainer', other_metrics)

        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,**kwargs)
905
+
906
+
907
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that drops every record whose message contains `text`."""

        def __init__(self, text):
            # Initialise the base Filter so its own attributes exist as well.
            super().__init__()
            self.text = text

        def filter(self, x):
            # Returning False suppresses the record.
            return self.text not in x.getMessage()

    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
914
+
unsloth_compiled_cache/UnslothDPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation, wandb)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
+ torch_compile_options = {
27
+ "epilogue_fusion" : True,
28
+ "max_autotune" : False,
29
+ "shape_padding" : True,
30
+ "trace.enabled" : False,
31
+ "triton.cudagraphs" : False,
32
+ }
33
+
34
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
+ def chunked_selective_log_softmax(logits, index):
36
+ # Split into 4 chunks only
37
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
+ all_per_token_logps = []
40
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
+ chunk_logits = chunk_logits.to(torch.float32)
43
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
+ per_token_logps = selected_logits - logsumexp_values
46
+ all_per_token_logps.append(per_token_logps)
47
+ pass
48
+ all_per_token_logps = torch.concat(all_per_token_logps)
49
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
+ return all_per_token_logps
51
@dataclass
class UnslothGKDConfig(GKDConfig):
    """

    Configuration class for [`GKDTrainer`].

    This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.

    Args:
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        lmbda (`float`, *optional*, defaults to `0.5`):
            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
            student-generated outputs).
        beta (`float`, *optional*, defaults to `0.5`):
            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
        max_new_tokens (`int`, *optional*, defaults to `128`):
            Maximum number of tokens to generate per completion.
        teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
            trained.
        teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        seq_kd (`bool`, *optional*, defaults to `False`):
            Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
            teacher-generated output).
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to `GKDConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. Stored on the instance; not forwarded to `GKDConfig`.
        max_seq_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum sequence length to truncate to. Stored on the instance; not forwarded to `GKDConfig`.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.
        ValueError: If `temperature` is not in the open interval `(0, 10)`.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = True,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = True,
        model_init_kwargs = None,
        chat_template_path = None,
        dataset_text_field = 'text',
        dataset_kwargs = None,
        dataset_num_proc = None,
        eos_token = None,
        pad_token = None,
        max_length = 1024,
        packing = False,
        packing_strategy = 'bfd',
        padding_free = False,
        pad_to_multiple_of = None,
        eval_packing = None,
        completion_only_loss = None,
        assistant_only_loss = False,
        activation_offloading = False,
        temperature = 0.9,
        lmbda = 0.5,
        beta = 0.5,
        max_new_tokens = 128,
        teacher_model_name_or_path = None,
        teacher_model_init_kwargs = None,
        disable_dropout = True,
        seq_kd = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        max_seq_length = None,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # With every save-related argument left at its default, pick a default
        # output directory and switch periodic checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            # NOTE(review): min(..., 2) evaluates to 2 for any cpu_count() >= 1 -- confirm a cap of 2 is intended.
            dataset_num_proc = min(cpu_count()*2, 2)
        # `MathError` is not a defined exception, so these checks used to fail
        # with NameError instead of reporting the message; raise ValueError.
        if temperature <= 0:
            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
        elif temperature >= 10:
            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            chat_template_path = chat_template_path,
            dataset_text_field = dataset_text_field,
            dataset_kwargs = dataset_kwargs,
            dataset_num_proc = dataset_num_proc,
            eos_token = eos_token,
            pad_token = pad_token,
            max_length = max_length,
            packing = packing,
            packing_strategy = packing_strategy,
            padding_free = padding_free,
            pad_to_multiple_of = pad_to_multiple_of,
            eval_packing = eval_packing,
            completion_only_loss = completion_only_loss,
            assistant_only_loss = assistant_only_loss,
            activation_offloading = activation_offloading,
            temperature = temperature,
            lmbda = lmbda,
            beta = beta,
            max_new_tokens = max_new_tokens,
            teacher_model_name_or_path = teacher_model_name_or_path,
            teacher_model_init_kwargs = teacher_model_init_kwargs,
            disable_dropout = disable_dropout,
            seq_kd = seq_kd,**kwargs)
        # Unsloth-specific options are kept on the instance rather than being
        # passed to the parent constructor.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
426
+
427
+ class _UnslothGKDTrainer(SFTTrainer):
428
+ _tag_names = ["trl", "gkd"]
429
+
430
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
    args: Optional[GKDConfig] = None,
    data_collator: Optional[DataCollator] = None,  # type: ignore
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional["PeftConfig"] = None,
    formatting_func: Optional[Callable] = None,
):
    """
    Set up the student trainer, load and prepare the teacher model, and build the generation config.

    Args:
        model: Student model (instance or name) forwarded to the parent trainer.
        teacher_model: Teacher model instance, or a model name/path that is loaded with
            `AutoModelForCausalLM.from_pretrained`.
        args: GKD configuration. Its attributes are read unconditionally, so it must not be `None`.
        data_collator: Ignored -- a `DataCollatorForChatML` is always constructed in its place.
        train_dataset, eval_dataset, processing_class, compute_metrics, callbacks, optimizers,
        preprocess_logits_for_metrics, peft_config, formatting_func: Forwarded unchanged to the parent trainer.

    Raises:
        ValueError: If `args.teacher_model_init_kwargs` is set while `teacher_model` is already instantiated.
    """
    # add remove_unused_columns=False to the dataclass args
    args.remove_unused_columns = False
    data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

    super().__init__(
        model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        peft_config=peft_config,
        formatting_func=formatting_func,
    )

    if args.teacher_model_init_kwargs is None:
        teacher_model_init_kwargs = {}
    elif not isinstance(teacher_model, str):
        raise ValueError(
            "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
        )
    else:
        # Work on a copy so the dtype conversion does not rewrite the user's config in place.
        teacher_model_init_kwargs = dict(args.teacher_model_init_kwargs)
        # Only a dtype given by name (e.g. "bfloat16") needs resolving. A missing
        # key, `None`, "auto", or an actual `torch.dtype` is passed through as is.
        torch_dtype = teacher_model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, str) and torch_dtype != "auto":
            teacher_model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)

    if isinstance(teacher_model, str):
        teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(self.model)

    if self.is_deepspeed_enabled:
        self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
    else:
        self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

    self.lmbda = args.lmbda
    self.beta = args.beta
    self.temperature = args.temperature
    self.seq_kd = args.seq_kd

    self.generation_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        do_sample=True,
        top_k=0,
        use_cache=False if args.gradient_checkpointing else True,
        pad_token_id=self.processing_class.pad_token_id,
    )
    # Set custom EOS tokens if they are specified by the model's generation
    # config. This is important for models with the Llama 3 chat template,
    # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
    # turns or messages.
    if (
        hasattr(self.model.generation_config, "eos_token_id")
        and self.model.generation_config.eos_token_id is not None
    ):
        self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
515
+
516
+ @staticmethod
517
+ def generalized_jsd_loss(
518
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
519
+ ):
520
+ """
521
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
522
+ of https://huggingface.co/papers/2306.13649 for the definition.
523
+
524
+ Args:
525
+ student_logits:
526
+ Tensor of shape (batch_size, sequence_length, vocab_size)
527
+ teacher_logits:
528
+ Tensor of shape (batch_size, sequence_length, vocab_size)
529
+ labels:
530
+ Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
531
+ loss
532
+ beta:
533
+ Interpolation coefficient between 0 and 1 (default: 0.5)
534
+ temperature:
535
+ Softmax temperature (default: 1.0)
536
+ reduction:
537
+ Specifies the reduction to apply to the output (default: 'batchmean')
538
+
539
+ Returns:
540
+ loss: Scalar tensor with the generalized JSD loss
541
+ """
542
+
543
+ # Apply temperature scaling
544
+ student_logits = student_logits / temperature
545
+ teacher_logits = teacher_logits / temperature
546
+
547
+ # Compute log probabilities for student and probabilities for teacher
548
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
549
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
550
+
551
+ if beta == 0:
552
+ jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
553
+ elif beta == 1:
554
+ jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
555
+ else:
556
+ # Compute the log of the mixture distribution
557
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
558
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
559
+ mixture_log_probs = torch.logsumexp(
560
+ torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
561
+ dim=0,
562
+ )
563
+
564
+ # Compute KL divergences using F.kl_div
565
+ # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
566
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
567
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
568
+
569
+ # Compute the Generalized Jensen-Shannon Divergence
570
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
571
+
572
+ # Masking
573
+ if labels is not None:
574
+ mask = labels != -100
575
+ jsd = jsd[mask]
576
+
577
+ # Apply reduction
578
+ if reduction == "batchmean":
579
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
580
+ elif reduction == "sum":
581
+ return jsd.sum()
582
+ elif reduction == "mean":
583
+ return jsd.mean()
584
+ else:
585
+ return jsd
586
+
587
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    Compute the distillation loss between the student `model` and `self.teacher_model`.

    Both models are run on `inputs["input_ids"]` / `inputs["attention_mask"]`; the
    teacher is switched to eval mode and run without gradient tracking. Logits are
    sliced with the width of `inputs["prompts"]` so that only the positions after
    the prompt are compared, and the result of `self.generalized_jsd_loss` (called
    with `beta=self.beta`) is returned.

    Args:
        model: The student model; called with `input_ids` and `attention_mask`.
        inputs: Mapping providing `input_ids`, `attention_mask`, `prompts` and `labels`.
        return_outputs: When truthy, return `(loss, student_outputs)` instead of `loss`.
        num_items_in_batch: Accepted for interface compatibility; not used here.

    Returns:
        The loss, or a `(loss, student_outputs)` tuple when `return_outputs` is truthy.
    """
    token_ids = inputs["input_ids"]
    token_mask = inputs["attention_mask"]

    # Student forward pass (gradients tracked).
    student_out = model(input_ids=token_ids, attention_mask=token_mask)

    # Teacher forward pass: eval mode, no gradients.
    teacher = self.teacher_model
    teacher.eval()
    with torch.no_grad():
        teacher_out = teacher(input_ids=token_ids, attention_mask=token_mask)

    # Keep only the post-prompt region: logits are taken one step earlier than
    # the labels they are compared against.
    prompt_width = inputs["prompts"].shape[1]
    logit_window = slice(prompt_width - 1, -1)
    student_logits = student_out.logits[:, logit_window, :]
    teacher_logits = teacher_out.logits[:, logit_window, :]
    completion_labels = inputs["labels"][:, prompt_width:]

    loss = self.generalized_jsd_loss(
        student_logits=student_logits,
        teacher_logits=teacher_logits,
        labels=completion_labels,
        beta=self.beta,
    )

    # Release cached memory after the two forward passes.
    empty_cache()

    if return_outputs:
        return (loss, student_out)
    return loss
621
+
622
+ @staticmethod
623
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
624
+ # Generate output with respect to the prompt-only
625
+ generated_outputs = model.generate(
626
+ input_ids=inputs["prompts"],
627
+ attention_mask=inputs.get("prompt_attention_mask", None),
628
+ generation_config=generation_config,
629
+ return_dict_in_generate=True,
630
+ )
631
+
632
+ # Get the generated token IDs
633
+ generated_tokens = generated_outputs.sequences
634
+ # Calculate new attention mask
635
+ new_attention_mask = torch.ones_like(generated_tokens)
636
+ new_labels = generated_tokens.clone()
637
+
638
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
639
+ if pad_token_id is not None:
640
+ new_labels[new_labels == pad_token_id] = -100
641
+ new_attention_mask[generated_tokens == pad_token_id] = 0
642
+
643
+ return generated_tokens, new_attention_mask, new_labels
644
+
645
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """
    Run one Generalized Knowledge Distillation (GKD) training step.

    Implements the on-policy scheme described in the GKD paper:

    * when `self.seq_kd` is truthy, the batch is first replaced by sequences
      generated from `self.teacher_model`;
    * then, with probability `self.lmbda`, the batch is replaced by sequences
      generated from the student `model`.

    In both cases `inputs["input_ids"]`, `inputs["attention_mask"]` and
    `inputs["labels"]` are overwritten in place before the parent
    `training_step` is called.
    """

    def _replace_batch_with_generations(source_model):
        # Generate from the unwrapped model, then overwrite the batch tensors.
        with unwrap_model_for_generation(source_model, self.accelerator) as generator:
            gen_ids, gen_mask, gen_labels = self.generate_on_policy_outputs(
                generator, inputs, self.generation_config, self.processing_class.pad_token_id
            )
        inputs["input_ids"] = gen_ids
        inputs["attention_mask"] = gen_mask
        inputs["labels"] = gen_labels

    if self.seq_kd:
        _replace_batch_with_generations(self.teacher_model)

    if random.random() <= self.lmbda:
        _replace_batch_with_generations(model)

    return super().training_step(model, inputs, num_items_in_batch)
674
+
675
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Write a draft model card (`README.md`) into `self.args.output_dir`.

    Does nothing unless `self.is_world_process_zero()` is truthy.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    config = self.model.config

    # Only report a base model when `_name_or_path` is not a local directory.
    base_model = None
    if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
        base_model = config._name_or_path

    # Normalise `tags` into a fresh mutable set.
    if tags is None:
        tag_set = set()
    elif isinstance(tags, str):
        tag_set = {tags}
    else:
        tag_set = set(tags)

    if hasattr(config, "unsloth_version"):
        tag_set.add("unsloth")
    tag_set.update(self._tag_names)

    # NOTE(review): leading whitespace inside this literal was not recoverable
    # from the rendered diff - confirm field indentation against the raw file.
    citation = textwrap.dedent("""\
    @inproceedings{agarwal2024on-policy,
    title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
    author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
    year = 2024,
    booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
    publisher = {OpenReview.net},
    url = {https://openreview.net/forum?id=3zKtaqxLhW},
    }""")

    card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tag_set,
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="GKD",
        trainer_citation=citation,
        paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
        paper_id="2306.13649",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    card.save(readme_path)
738
class UnslothGKDTrainer(_UnslothGKDTrainer):
    """
    Subclass of `_UnslothGKDTrainer` whose constructor adjusts `args`, the
    padding side of the tokenizer / processing class, and the data collator
    before delegating to the parent constructor, and then patches a few
    attributes on the trainer and model afterwards.
    """
    def __init__(
        self,
        model = None,
        teacher_model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        # Fall back to a default config when the caller supplies none.
        if args is None: args = UnslothGKDConfig()

        # --- Mixed-precision selection ---------------------------------------
        # Read the requested flags; anything that is not a real bool counts as False.
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        # UNSLOTH_FORCE_FLOAT32=1 disables both half-precision modes below.
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # NOTE(review): `model` defaults to None but is dereferenced here
        # unconditionally - confirm callers always pass a model.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        # Presumably normalises `dtype` to a torch.dtype (it is compared with
        # torch.float16 on the next line) - verify against unsloth_zoo.
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Refuse a precision flag that contradicts the model's own dtype.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: pick the one matching the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # --- Evaluation defaults ---------------------------------------------
        # NOTE(review): this reads `args.eval_dataset`, not the `eval_dataset`
        # parameter of this constructor - confirm that is intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Only prints an upgrade hint for transformers <= 4.45.2; nothing else changes.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps

        # --- Full-eval precision flags ----------------------------------------
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        # Align a mismatched full-eval flag with the training precision chosen above.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # --- Logits for metrics -----------------------------------------------
        # Set UNSLOTH_RETURN_LOGITS when either metrics hook is provided.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # --- Sequence length --------------------------------------------------
        # No `max_seq_length` local is bound at this point, so the first branch
        # is taken exactly when `args` has no `max_seq_length` attribute.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            # Inherit the model's value when `args` leaves it unset.
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # --- Tokenizer padding side -------------------------------------------
        # This constructor has no `tokenizer` parameter, so only the
        # `processing_class` branch below can apply here.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer

        # --- Data collator selection ------------------------------------------
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap between the seq2seq and language-modeling collators depending
            # on whether the training dataset has a 'labels' column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator: keep all columns and skip dataset preparation.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # When the object has no `pad` but wraps a `.tokenizer`, rebuild the
            # collator around the inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('gkd_trainer', other_metrics)

        super().__init__(
            model = model,
            teacher_model = teacher_model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,**kwargs)

        # --- Post-init patches ------------------------------------------------
        # Drop the NEFTune hook installed by the parent (if any) and instead set
        # the noise alpha directly on the input embeddings.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        # Attach the accelerator's scaler to every nested `.model` level.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
879
+
880
+ pass
unsloth_compiled_cache/UnslothGRPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothIterativeSFTTrainer.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.iterative_sft_trainer import (AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, DataLoader, Dataset, EvalLoopOutput, FeatureExtractionMixin, IterativeSFTConfig, IterativeSFTTrainer, Optional, PPODecorators, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainingArguments, Union, generate_model_card, get_comet_experiment_url, is_peft_available, is_wandb_available, os, torch, wandb, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options handed to `torch.compile` via its `options=` argument by the
# compiled helper that follows in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return, for every position, the log-softmax value of the entry picked by `index`.

    `logits` is flattened to rows of size `logits.shape[-1]` and `index` to a flat
    vector; both are split into at most 4 chunks along the row axis. Each chunk is
    upcast to float32 and the selected logit minus the row-wise logsumexp is
    computed. The concatenated result is reshaped to
    `(logits.shape[0], logits.shape[1])`.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)

    # Split into 4 chunks only
    logit_parts = torch.chunk(flat_logits, chunks = 4, dim = 0)
    index_parts = torch.chunk(flat_index, chunks = 4, dim = 0)

    logps = []
    for part, ids in zip(logit_parts, index_parts):
        part = part.to(torch.float32)
        picked = torch.gather(part, dim = -1, index = ids.unsqueeze(-1)).squeeze(-1)
        # log_softmax(x)[i] == x[i] - logsumexp(x)
        logps.append(picked - torch.logsumexp(part, dim = -1))

    merged = torch.cat(logps)
    return merged.reshape((logits.shape[0], logits.shape[1]))
51
@dataclass
class UnslothIterativeSFTConfig(IterativeSFTConfig):
    """

    Configuration class for the [`IterativeSFTTrainer`].

    This class includes only the parameters that are specific to Iterative SFT training. For a full list of training
    arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
    class may differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`IterativeSFTTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        max_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            The truncation mode to use, either `"keep_end"` or `"keep_start"`.
        optimize_device_cache (`bool`, *optional*, defaults to `False`):
            Whether to optimize accelerator cache for slightly more memory-efficient training.

    """
    # Extra fields declared on top of `IterativeSFTConfig`. They are not passed
    # to the parent constructor; `__init__` assigns them after `super().__init__`.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    # Explicit constructor: validates `learning_rate`, adjusts the checkpoint
    # defaults, forwards every named argument (plus **kwargs) to the parent
    # config, then stores the three extra fields declared above.
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = True,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = True,
        model_init_kwargs = None,
        max_length = None,
        truncation_mode = 'keep_end',
        optimize_device_cache = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        max_seq_length = None,
        **kwargs,
    ):
        # Reject learning rates below 1e-7 or above 1 before doing anything else.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # With no output_dir and the save settings still at their signature
        # defaults, use a fixed directory name and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # Forward everything except the three Unsloth-specific fields.
        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            max_length = max_length,
            truncation_mode = truncation_mode,
            optimize_device_cache = optimize_device_cache,**kwargs)
        # Store the Unsloth-specific fields that the parent does not know about.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
375
+ pass
376
+
377
+ class _UnslothIterativeSFTTrainer(Trainer):
378
+ """"""
379
+
380
+ _tag_names = ["trl", "iterative-sft"]
381
+
382
+ def __init__(
383
+ self,
384
+ model: Union[str, PreTrainedModel],
385
+ args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None,
386
+ data_collator: Optional[DataCollator] = None,
387
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
388
+ processing_class: Optional[
389
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
390
+ ] = None,
391
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
392
+ None,
393
+ None,
394
+ ),
395
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
396
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
397
+ ):
398
+ # Args
399
+ model_id = model if isinstance(model, str) else model.config._name_or_path
400
+ if args is None:
401
+ model_name = model_id.split("/")[-1]
402
+ args = IterativeSFTConfig(f"{model_name}-IterativeSFT")
403
+ elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig):
404
+ dict_args = args.to_dict()
405
+ dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
406
+ dict_args.pop("push_to_hub_token")
407
+ args = IterativeSFTConfig(**dict_args)
408
+
409
+ # Handle the tokenizer
410
+ if processing_class is None:
411
+ processing_class = AutoTokenizer.from_pretrained(model_id)
412
+
413
+ # Model
414
+ if args.model_init_kwargs is not None and not isinstance(model, str):
415
+ warnings.warn(
416
+ "You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. "
417
+ "The `model_init_kwargs` will be ignored."
418
+ )
419
+ if isinstance(model, str):
420
+ model = self._create_model_from_path(model, args)
421
+
422
+ # PEFT configuration and model wrapping
423
+ if is_peft_available() and isinstance(model, PeftModel):
424
+ self.is_peft_model = True
425
+ else:
426
+ self.is_peft_model = False
427
+
428
+ self.processing_class = processing_class
429
+ self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
430
+
431
+ if data_collator is None:
432
+ if self.is_encoder_decoder:
433
+ self.data_collator = DataCollatorForSeq2Seq(
434
+ processing_class, label_pad_token_id=-100, pad_to_multiple_of=8
435
+ )
436
+ else:
437
+ self.data_collator = DataCollatorForLanguageModeling(self.processing_class, mlm=False)
438
+ else:
439
+ self.data_collator = data_collator
440
+
441
+ self.max_length = args.max_length
442
+ self.truncation_mode = args.truncation_mode
443
+ self.optimize_device_cache = args.optimize_device_cache
444
+
445
+ super().__init__(
446
+ model=model,
447
+ args=args,
448
+ data_collator=self.data_collator,
449
+ eval_dataset=eval_dataset,
450
+ processing_class=processing_class,
451
+ compute_metrics=compute_metrics,
452
+ optimizers=optimizers,
453
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
454
+ )
455
+
456
+ # Add tags for models that have been loaded with the correct transformers version
457
+ if hasattr(self.model, "add_model_tags"):
458
+ self.model.add_model_tags(self._tag_names)
459
+
460
+ self.create_optimizer_and_scheduler(self.args.max_steps)
461
+
462
+ # prepare model, optimizer and lr_scheduler
463
+ self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
464
+ self.model, self.optimizer, self.lr_scheduler
465
+ )
466
+
467
+ self.processing_class.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"
468
+
469
+ if not hasattr(self, "accelerator"):
470
+ raise AttributeError(
471
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
472
+ )
473
+
474
+ PPODecorators.optimize_device_cache = self.optimize_device_cache
475
+
476
+ def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel:
477
+ """Creates a model from a path or model identifier."""
478
+ model_init_kwargs = args.model_init_kwargs or {}
479
+ return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
480
+
481
+ def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
482
+ if attention_mask is None:
483
+ attention_mask = [torch.ones_like(ids) for ids in input_ids]
484
+
485
+ if self.is_encoder_decoder:
486
+ input_data = self.data_collator(
487
+ [
488
+ {"input_ids": ids, "attention_mask": att, "labels": lab}
489
+ for ids, att, lab in zip(input_ids, attention_mask, labels)
490
+ ]
491
+ ).to(self.model.device)
492
+
493
+ input_data.pop("decoder_input_ids", None) # This is directly computed inside the model
494
+
495
+ input_data["labels"][input_data["labels"] == self.processing_class.pad_token_id] = -100
496
+
497
+ else:
498
+ input_data = self.data_collator(
499
+ [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
500
+ ).to(self.model.device)
501
+
502
+ # truncate in case the user has provided input_ids, attention_mask and labels
503
+ if self.max_length is not None:
504
+ if self.truncation_mode == "keep_start":
505
+ input_data = {k: v[: self.max_length] for k, v in input_data.items()}
506
+ elif self.truncation_mode == "keep_end":
507
+ input_data = {k: v[-self.max_length :] for k, v in input_data.items()}
508
+ else:
509
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
510
+
511
+ return input_data
512
+
513
+ @staticmethod
514
+ def _step_safety_checker(
515
+ input_ids: list[torch.LongTensor],
516
+ attention_mask: list[torch.LongTensor],
517
+ labels: list[torch.LongTensor],
518
+ texts: list[str],
519
+ texts_labels: list[str],
520
+ ):
521
+ """
522
+ Check if the input data is valid for training.
523
+
524
+ Args:
525
+ input_ids (list[`torch.LongTensor`]):
526
+ List of tensors containing the input_ids
527
+ attention_mask (list[`torch.LongTensor`]):
528
+ List of tensors containing the attention_mask
529
+ labels (list[`torch.FloatTensor`]):
530
+ List of tensors containing the labels
531
+ texts (list[`str`]):
532
+ List of string containing the text input.
533
+ texts_labels (list[`str`]):
534
+ List of string containing the text labels.
535
+
536
+ Returns:
537
+ `tuple`: The input data.
538
+ """
539
+ if texts is None:
540
+ if attention_mask is None:
541
+ for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]):
542
+ if not isinstance(tensor_list, list):
543
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
544
+ if not isinstance(tensor_list[0], torch.Tensor):
545
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
546
+ else:
547
+ for name, tensor_list in zip(
548
+ ["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]
549
+ ):
550
+ if not isinstance(tensor_list, list):
551
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
552
+ if not isinstance(tensor_list[0], torch.Tensor):
553
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
554
+ else:
555
+ if not isinstance(texts, list):
556
+ raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
557
+ if not isinstance(texts[0], str):
558
+ raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
559
+ if texts_labels is not None:
560
+ if not isinstance(texts_labels, list):
561
+ raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
562
+ if not isinstance(texts_labels[0], str):
563
+ raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")
564
+
565
+ return input_ids, attention_mask, labels, texts, texts_labels
566
+
567
+ @PPODecorators.empty_device_cache()
568
+ def step(
569
+ self,
570
+ input_ids: Optional[list[torch.LongTensor]] = None,
571
+ attention_mask: Optional[list[torch.LongTensor]] = None,
572
+ labels: Optional[list[torch.LongTensor]] = None,
573
+ texts: Optional[list[str]] = None,
574
+ texts_labels: Optional[list[str]] = None,
575
+ ):
576
+ """
577
+ Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and
578
+ text_labels.
579
+
580
+ Args:
581
+ input_ids (list[`torch.LongTensor`]):
582
+ List of tensors containing the input_ids (if not provided, text will be used)
583
+ attention_mask (list[`torch.LongTensor`], , *optional*):
584
+ List of tensors containing the attention_mask
585
+ labels (list[`torch.FloatTensor`], *optional*):
586
+ List of tensors containing the labels (if set to None, will default to input_ids)
587
+ texts (list[`str`], *optional*):
588
+ List of strings containing the text input (if not provided, input_ids will directly be used)
589
+ texts_labels (list[`str`], *optional*):
590
+ List of strings containing the text labels (if set to None, will default to text)
591
+
592
+ Returns:
593
+ `dict[str, Any]`: A summary of the training statistics
594
+ """
595
+ self.model.train()
596
+
597
+ if self.state.global_step == 0:
598
+ self.tr_loss = torch.tensor(0.0).to(self.args.device)
599
+ self._globalstep_last_logged = self.state.global_step
600
+
601
+ if input_ids is None and texts is None:
602
+ raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
603
+ elif input_ids is not None and texts is not None:
604
+ warnings.warn(
605
+ "Both `input_ids` and `texts` argument are provided. `input_ids` will be ignored. "
606
+ "Please provide only one of the two.",
607
+ UserWarning,
608
+ )
609
+
610
+ if labels is None and texts_labels is None and self.is_encoder_decoder:
611
+ raise ValueError(
612
+ "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
613
+ )
614
+
615
+ # Convert Column to list if not already
616
+ input_ids = input_ids[:] if input_ids is not None else None
617
+ attention_mask = attention_mask[:] if attention_mask is not None else None
618
+ labels = labels[:] if labels is not None else None
619
+ texts = texts[:] if texts is not None else None
620
+ texts_labels = texts_labels[:] if texts_labels is not None else None
621
+
622
+ input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
623
+ input_ids, attention_mask, labels, texts, texts_labels
624
+ )
625
+
626
+ if texts is not None:
627
+ model_inputs = self.processing_class(
628
+ texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
629
+ )
630
+
631
+ input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]
632
+
633
+ if texts_labels is not None:
634
+ labels = self.processing_class(
635
+ texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
636
+ )["input_ids"]
637
+
638
+ if labels is None:
639
+ labels = input_ids
640
+
641
+ model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)
642
+
643
+ model_inputs_names = list(model_inputs.keys())
644
+
645
+ batch_dict = {}
646
+ batch_dict.update(model_inputs)
647
+
648
+ def collator(data):
649
+ return_dict = dict()
650
+ for key in data[0]:
651
+ if key in ["input_ids", "attention_mask", "labels"]:
652
+ return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
653
+ return return_dict
654
+
655
+ batch_data = Dataset.from_dict(batch_dict)
656
+ batch_data.set_format("torch")
657
+
658
+ step_dataloader = DataLoader(
659
+ batch_data,
660
+ batch_size=self.args.per_device_train_batch_size,
661
+ shuffle=True,
662
+ collate_fn=collator,
663
+ )
664
+
665
+ for _, batch in enumerate(step_dataloader):
666
+ with self.accelerator.accumulate(self.model):
667
+ model_inputs = {k: batch[k] for k in model_inputs_names}
668
+ loss = self.compute_loss(self.model, model_inputs)
669
+
670
+ if self.args.n_gpu > 1:
671
+ loss = loss.mean()
672
+
673
+ tr_loss_step = loss.detach()
674
+
675
+ self.accelerator.backward(loss)
676
+
677
+ if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
678
+ self.accelerator.clip_grad_norm_(
679
+ self.model.parameters(),
680
+ self.args.max_grad_norm,
681
+ )
682
+
683
+ self.optimizer.step()
684
+ self.optimizer.zero_grad()
685
+ if self.lr_scheduler is not None:
686
+ self.lr_scheduler.step()
687
+
688
+ self.state.global_step += 1
689
+
690
+ # update stats etc
691
+ self.tr_loss += tr_loss_step
692
+
693
+ self._maybe_log_save_evaluate()
694
+
695
+ def _maybe_log_save_evaluate(self):
696
+ # check if eval is required
697
+ if self.args.eval_steps is not None:
698
+ if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
699
+ self.evaluate(self.eval_dataset)
700
+
701
+ # check if logging is required
702
+ if self.args.logging_steps is not None:
703
+ if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
704
+ logs: dict[str, float] = {}
705
+
706
+ tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()
707
+
708
+ # reset tr_loss to zero
709
+ self.tr_loss -= self.tr_loss
710
+
711
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
712
+ logs["learning_rate"] = self._get_learning_rate()
713
+
714
+ self._globalstep_last_logged = self.state.global_step
715
+
716
+ self.log(logs)
717
+
718
+ # Ensure the model card is saved along with the checkpoint
719
+ def _save_checkpoint(self, model, trial):
720
+ if self.args.hub_model_id is None:
721
+ model_name = Path(self.args.output_dir).name
722
+ else:
723
+ model_name = self.args.hub_model_id.split("/")[-1]
724
+ self.create_model_card(model_name=model_name)
725
+ super()._save_checkpoint(model, trial)
726
+
727
+ def create_model_card(
728
+ self,
729
+ model_name: Optional[str] = None,
730
+ dataset_name: Optional[str] = None,
731
+ tags: Union[str, list[str], None] = None,
732
+ ):
733
+ """
734
+ Creates a draft of a model card using the information available to the `Trainer`.
735
+
736
+ Args:
737
+ model_name (`str` or `None`, *optional*, defaults to `None`):
738
+ Name of the model.
739
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
740
+ Name of the dataset used for training.
741
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
742
+ Tags to be associated with the model card.
743
+ """
744
+ if not self.is_world_process_zero():
745
+ return
746
+
747
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
748
+ base_model = self.model.config._name_or_path
749
+ else:
750
+ base_model = None
751
+
752
+ # normalize `tags` to a mutable set
753
+ if tags is None:
754
+ tags = set()
755
+ elif isinstance(tags, str):
756
+ tags = {tags}
757
+ else:
758
+ tags = set(tags)
759
+
760
+ if hasattr(self.model.config, "unsloth_version"):
761
+ tags.add("unsloth")
762
+
763
+ tags.update(self._tag_names)
764
+
765
+ model_card = generate_model_card(
766
+ base_model=base_model,
767
+ model_name=model_name,
768
+ hub_model_id=self.hub_model_id,
769
+ dataset_name=dataset_name,
770
+ tags=tags,
771
+ wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
772
+ comet_url=get_comet_experiment_url(),
773
+ trainer_name="Iterative SFT",
774
+ )
775
+
776
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
777
class UnslothIterativeSFTTrainer(_UnslothIterativeSFTTrainer):
    """

    Trainer for fine-tuning workflows that need custom work between optimisation steps.

    Before delegating to `_UnslothIterativeSFTTrainer`, the constructor reconciles the precision flags on `args`
    with the model dtype, adjusts evaluation settings, and forces right-side padding on the processing class.

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        args ([`IterativeSFTConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default `UnslothIterativeSFTConfig` is created.
        data_collator (`DataCollator`, *optional*):
            Function used to form a batch from a list of elements. Forwarded to the parent trainer.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If `None`, the parent trainer loads it from the model's name
            with [`~transformers.AutoTokenizer.from_pretrained`].
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.
        **kwargs:
            Any further keyword arguments (for example `optimizers`) are passed through to the parent constructor.

    """
    def __init__(
        self,
        model,
        args = None,
        data_collator = None,
        eval_dataset = None,
        processing_class = None,
        preprocess_logits_for_metrics = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None:
            args = UnslothIterativeSFTConfig()

        # ---- requested precision flags (anything that is not a real bool counts as False) ----
        use_bf16 = getattr(args, 'bf16', False)
        if not isinstance(use_bf16, bool):
            use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if not isinstance(use_fp16, bool):
            use_fp16 = False

        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        # ---- model dtype: config value first, input-embedding dtype as fallback ----
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        model_is_fp16 = dtype == torch.float16

        if not force_float32:
            if model_is_fp16 and use_bf16:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not model_is_fp16 and use_fp16:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif not (use_bf16 or use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing requested explicitly: follow the model dtype.
            args.fp16 = model_is_fp16
            args.bf16 = not model_is_fp16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if model_is_fp16 else 'bf16'

        # ---- evaluation defaults ----
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # ---- full-eval precision flags ----
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if not isinstance(fp16_full_eval, bool):
            fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if not isinstance(bf16_full_eval, bool):
            bf16_full_eval = False

        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False

        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metric callbacks are given logits, so request them via the environment flag.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the model's max_seq_length when the config has the field but leaves it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model.max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # Force right-side padding on the processing class and on any tokenizer it wraps.
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('iterative_sft_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            compute_metrics = compute_metrics,
            **kwargs
        )

        # Remove any NEFTune hook left on the instance and store the noise alpha on the embeddings instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        # Attach the accelerator's scaler to the model and to every nested `.model` below it.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            node = model
            while True:
                node.accelerator_scaler = scaler
                if not hasattr(node, 'model'):
                    break
                node = node.model
923
+
924
+ pass
unsloth_compiled_cache/UnslothKTOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothNashMDTrainer.py ADDED
@@ -0,0 +1,1019 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Option dictionary handed to `torch.compile(options=...)` by the decorator on
# `chunked_selective_log_softmax` below.
# NOTE(review): the keys look like TorchInductor config names — confirm their exact
# meaning against the installed torch version before changing any value.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-probability of each token id in `index` under `logits`.

    The leading dimensions are flattened and processed in at most 4 chunks, each upcast to float32, so the
    full float32 copy of `logits` is never built in one piece. The result is reshaped to
    `(logits.shape[0], logits.shape[1])`.
    """
    vocab_size = logits.shape[-1]
    # Split into 4 chunks only
    logit_chunks = torch.chunk(logits.reshape(-1, vocab_size), chunks = 4, dim = 0)
    index_chunks = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)

    pieces = []
    for logit_chunk, index_chunk in zip(logit_chunks, index_chunks):
        logit_chunk = logit_chunk.to(torch.float32)
        # log_softmax(x)[i] == x[i] - logsumexp(x), evaluated only at the selected ids
        picked = torch.gather(logit_chunk, dim = -1, index = index_chunk.unsqueeze(-1)).squeeze(-1)
        normaliser = torch.logsumexp(logit_chunk, dim = -1)
        pieces.append(picked - normaliser)

    flat_logps = torch.concat(pieces)
    return flat_logps.reshape((logits.shape[0], logits.shape[1]))
51
+ @dataclass
52
+ class UnslothNashMDConfig(NashMDConfig):
53
+ """
54
+
55
+ Configuration class for the [`NashMDTrainer`].
56
+
57
+ Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
58
+
59
+ Parameters:
60
+ mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
61
+ Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
62
+ mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
63
+ epochs.
64
+
65
+ """
66
+ vllm_sampling_params: Optional[Any] = field(
67
+ default = None,
68
+ metadata = {'help': 'vLLM SamplingParams'},
69
+ )
70
+ unsloth_num_chunks : Optional[int] = field(
71
+ default = -1,
72
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
73
+ )
74
+ max_seq_length : Optional[int] = field(
75
+ default = None,
76
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
77
+ )
78
+ def __init__(
79
+ self,
80
+ output_dir = None,
81
+ overwrite_output_dir = None,
82
+ do_train = False,
83
+ do_eval = False,
84
+ do_predict = False,
85
+ eval_strategy = 'no',
86
+ prediction_loss_only = False,
87
+ per_device_train_batch_size = 4,
88
+ per_device_eval_batch_size = 4,
89
+ per_gpu_train_batch_size = None,
90
+ per_gpu_eval_batch_size = None,
91
+ gradient_accumulation_steps = 2,
92
+ eval_accumulation_steps = 2,
93
+ eval_delay = 0,
94
+ torch_empty_cache_steps = 250,
95
+ learning_rate = 5e-05,
96
+ weight_decay = 0.01,
97
+ adam_beta1 = 0.9,
98
+ adam_beta2 = 0.999,
99
+ adam_epsilon = 1e-08,
100
+ max_grad_norm = 1.0,
101
+ num_train_epochs = 3.0,
102
+ max_steps = -1,
103
+ lr_scheduler_type = 'linear',
104
+ warmup_ratio = 0.1,
105
+ warmup_steps = 0,
106
+ log_level = 'passive',
107
+ log_level_replica = 'warning',
108
+ log_on_each_node = True,
109
+ logging_dir = None,
110
+ logging_strategy = 'steps',
111
+ logging_first_step = False,
112
+ logging_steps = 1,
113
+ logging_nan_inf_filter = False,
114
+ save_strategy = 'steps',
115
+ save_steps = 500,
116
+ save_total_limit = None,
117
+ save_safetensors = True,
118
+ save_on_each_node = False,
119
+ save_only_model = False,
120
+ restore_callback_states_from_checkpoint = False,
121
+ no_cuda = False,
122
+ use_cpu = False,
123
+ use_mps_device = False,
124
+ seed = 3407,
125
+ data_seed = 3407,
126
+ jit_mode_eval = False,
127
+ use_ipex = False,
128
+ bf16 = False,
129
+ fp16 = False,
130
+ fp16_opt_level = 'O1',
131
+ half_precision_backend = 'auto',
132
+ bf16_full_eval = False,
133
+ fp16_full_eval = False,
134
+ tf32 = None,
135
+ local_rank = -1,
136
+ ddp_backend = None,
137
+ tpu_num_cores = None,
138
+ tpu_metrics_debug = False,
139
+ debug = '',
140
+ dataloader_drop_last = False,
141
+ eval_steps = None,
142
+ dataloader_num_workers = 0,
143
+ dataloader_prefetch_factor = None,
144
+ past_index = -1,
145
+ run_name = None,
146
+ disable_tqdm = None,
147
+ remove_unused_columns = True,
148
+ label_names = None,
149
+ load_best_model_at_end = False,
150
+ metric_for_best_model = None,
151
+ greater_is_better = None,
152
+ ignore_data_skip = False,
153
+ fsdp = '',
154
+ fsdp_min_num_params = 0,
155
+ fsdp_config = None,
156
+ fsdp_transformer_layer_cls_to_wrap = None,
157
+ accelerator_config = None,
158
+ deepspeed = None,
159
+ label_smoothing_factor = 0.0,
160
+ optim = 'adamw_8bit',
161
+ optim_args = None,
162
+ adafactor = False,
163
+ group_by_length = False,
164
+ length_column_name = 'length',
165
+ report_to = None,
166
+ ddp_find_unused_parameters = None,
167
+ ddp_bucket_cap_mb = None,
168
+ ddp_broadcast_buffers = None,
169
+ dataloader_pin_memory = True,
170
+ dataloader_persistent_workers = False,
171
+ skip_memory_metrics = True,
172
+ use_legacy_prediction_loop = False,
173
+ push_to_hub = False,
174
+ resume_from_checkpoint = None,
175
+ hub_model_id = None,
176
+ hub_strategy = 'every_save',
177
+ hub_token = None,
178
+ hub_private_repo = None,
179
+ hub_always_push = False,
180
+ hub_revision = None,
181
+ gradient_checkpointing = False,
182
+ gradient_checkpointing_kwargs = None,
183
+ include_inputs_for_metrics = False,
184
+ eval_do_concat_batches = True,
185
+ fp16_backend = 'auto',
186
+ push_to_hub_model_id = None,
187
+ push_to_hub_organization = None,
188
+ push_to_hub_token = None,
189
+ mp_parameters = '',
190
+ auto_find_batch_size = True,
191
+ full_determinism = False,
192
+ torchdynamo = None,
193
+ ray_scope = 'last',
194
+ ddp_timeout = 1800,
195
+ torch_compile = False,
196
+ torch_compile_backend = None,
197
+ torch_compile_mode = None,
198
+ include_tokens_per_second = False,
199
+ include_num_input_tokens_seen = False,
200
+ neftune_noise_alpha = None,
201
+ optim_target_modules = None,
202
+ batch_eval_metrics = False,
203
+ eval_on_start = False,
204
+ use_liger_kernel = False,
205
+ liger_kernel_config = None,
206
+ eval_use_gather_object = False,
207
+ average_tokens_across_devices = True,
208
+ reward_model_path = None,
209
+ judge = None,
210
+ max_new_tokens = 64,
211
+ max_length = 512,
212
+ temperature = 0.9,
213
+ missing_eos_penalty = None,
214
+ loss_type = 'sigmoid',
215
+ dataset_num_proc = None,
216
+ disable_dropout = True,
217
+ use_vllm = False,
218
+ vllm_model_impl = 'vllm',
219
+ gpu_memory_utilization = 0.55,
220
+ ds3_gather_for_generation = True,
221
+ model_init_kwargs = None,
222
+ vllm_sampling_params = None,
223
+ unsloth_num_chunks = -1,
224
+ max_seq_length = None,
225
+ **kwargs,
226
+ ):
227
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
228
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
229
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
230
+ output_dir = 'unsloth_training_checkpoints'
231
+ save_strategy = 'no'
232
+ if dataset_num_proc is None:
233
+ from multiprocessing import cpu_count
234
+ dataset_num_proc = min(cpu_count()*2, 2)
235
+ if temperature <= 0:
236
+ raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
237
+ elif temperature >= 10:
238
+ raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
239
+
240
+
241
+ super().__init__(
242
+ output_dir = output_dir,
243
+ overwrite_output_dir = overwrite_output_dir,
244
+ do_train = do_train,
245
+ do_eval = do_eval,
246
+ do_predict = do_predict,
247
+ eval_strategy = eval_strategy,
248
+ prediction_loss_only = prediction_loss_only,
249
+ per_device_train_batch_size = per_device_train_batch_size,
250
+ per_device_eval_batch_size = per_device_eval_batch_size,
251
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
252
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
253
+ gradient_accumulation_steps = gradient_accumulation_steps,
254
+ eval_accumulation_steps = eval_accumulation_steps,
255
+ eval_delay = eval_delay,
256
+ torch_empty_cache_steps = torch_empty_cache_steps,
257
+ learning_rate = learning_rate,
258
+ weight_decay = weight_decay,
259
+ adam_beta1 = adam_beta1,
260
+ adam_beta2 = adam_beta2,
261
+ adam_epsilon = adam_epsilon,
262
+ max_grad_norm = max_grad_norm,
263
+ num_train_epochs = num_train_epochs,
264
+ max_steps = max_steps,
265
+ lr_scheduler_type = lr_scheduler_type,
266
+ warmup_ratio = warmup_ratio,
267
+ warmup_steps = warmup_steps,
268
+ log_level = log_level,
269
+ log_level_replica = log_level_replica,
270
+ log_on_each_node = log_on_each_node,
271
+ logging_dir = logging_dir,
272
+ logging_strategy = logging_strategy,
273
+ logging_first_step = logging_first_step,
274
+ logging_steps = logging_steps,
275
+ logging_nan_inf_filter = logging_nan_inf_filter,
276
+ save_strategy = save_strategy,
277
+ save_steps = save_steps,
278
+ save_total_limit = save_total_limit,
279
+ save_safetensors = save_safetensors,
280
+ save_on_each_node = save_on_each_node,
281
+ save_only_model = save_only_model,
282
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
283
+ no_cuda = no_cuda,
284
+ use_cpu = use_cpu,
285
+ use_mps_device = use_mps_device,
286
+ seed = seed,
287
+ data_seed = data_seed,
288
+ jit_mode_eval = jit_mode_eval,
289
+ use_ipex = use_ipex,
290
+ bf16 = bf16,
291
+ fp16 = fp16,
292
+ fp16_opt_level = fp16_opt_level,
293
+ half_precision_backend = half_precision_backend,
294
+ bf16_full_eval = bf16_full_eval,
295
+ fp16_full_eval = fp16_full_eval,
296
+ tf32 = tf32,
297
+ local_rank = local_rank,
298
+ ddp_backend = ddp_backend,
299
+ tpu_num_cores = tpu_num_cores,
300
+ tpu_metrics_debug = tpu_metrics_debug,
301
+ debug = debug,
302
+ dataloader_drop_last = dataloader_drop_last,
303
+ eval_steps = eval_steps,
304
+ dataloader_num_workers = dataloader_num_workers,
305
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
306
+ past_index = past_index,
307
+ run_name = run_name,
308
+ disable_tqdm = disable_tqdm,
309
+ remove_unused_columns = remove_unused_columns,
310
+ label_names = label_names,
311
+ load_best_model_at_end = load_best_model_at_end,
312
+ metric_for_best_model = metric_for_best_model,
313
+ greater_is_better = greater_is_better,
314
+ ignore_data_skip = ignore_data_skip,
315
+ fsdp = fsdp,
316
+ fsdp_min_num_params = fsdp_min_num_params,
317
+ fsdp_config = fsdp_config,
318
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
319
+ accelerator_config = accelerator_config,
320
+ deepspeed = deepspeed,
321
+ label_smoothing_factor = label_smoothing_factor,
322
+ optim = optim,
323
+ optim_args = optim_args,
324
+ adafactor = adafactor,
325
+ group_by_length = group_by_length,
326
+ length_column_name = length_column_name,
327
+ report_to = report_to,
328
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
329
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
330
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
331
+ dataloader_pin_memory = dataloader_pin_memory,
332
+ dataloader_persistent_workers = dataloader_persistent_workers,
333
+ skip_memory_metrics = skip_memory_metrics,
334
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
335
+ push_to_hub = push_to_hub,
336
+ resume_from_checkpoint = resume_from_checkpoint,
337
+ hub_model_id = hub_model_id,
338
+ hub_strategy = hub_strategy,
339
+ hub_token = hub_token,
340
+ hub_private_repo = hub_private_repo,
341
+ hub_always_push = hub_always_push,
342
+ hub_revision = hub_revision,
343
+ gradient_checkpointing = gradient_checkpointing,
344
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
345
+ include_inputs_for_metrics = include_inputs_for_metrics,
346
+ eval_do_concat_batches = eval_do_concat_batches,
347
+ fp16_backend = fp16_backend,
348
+ push_to_hub_model_id = push_to_hub_model_id,
349
+ push_to_hub_organization = push_to_hub_organization,
350
+ push_to_hub_token = push_to_hub_token,
351
+ mp_parameters = mp_parameters,
352
+ auto_find_batch_size = auto_find_batch_size,
353
+ full_determinism = full_determinism,
354
+ torchdynamo = torchdynamo,
355
+ ray_scope = ray_scope,
356
+ ddp_timeout = ddp_timeout,
357
+ torch_compile = torch_compile,
358
+ torch_compile_backend = torch_compile_backend,
359
+ torch_compile_mode = torch_compile_mode,
360
+ include_tokens_per_second = include_tokens_per_second,
361
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
362
+ neftune_noise_alpha = neftune_noise_alpha,
363
+ optim_target_modules = optim_target_modules,
364
+ batch_eval_metrics = batch_eval_metrics,
365
+ eval_on_start = eval_on_start,
366
+ use_liger_kernel = use_liger_kernel,
367
+ liger_kernel_config = liger_kernel_config,
368
+ eval_use_gather_object = eval_use_gather_object,
369
+ average_tokens_across_devices = average_tokens_across_devices,
370
+ reward_model_path = reward_model_path,
371
+ judge = judge,
372
+ max_new_tokens = max_new_tokens,
373
+ max_length = max_length,
374
+ temperature = temperature,
375
+ missing_eos_penalty = missing_eos_penalty,
376
+ loss_type = loss_type,
377
+ dataset_num_proc = dataset_num_proc,
378
+ disable_dropout = disable_dropout,
379
+ use_vllm = use_vllm,
380
+ vllm_model_impl = vllm_model_impl,
381
+ gpu_memory_utilization = gpu_memory_utilization,
382
+ ds3_gather_for_generation = ds3_gather_for_generation,
383
+ model_init_kwargs = model_init_kwargs,**kwargs)
384
+ self.vllm_sampling_params = vllm_sampling_params
385
+ self.unsloth_num_chunks = unsloth_num_chunks
386
+ self.max_seq_length = max_seq_length
387
+ pass
388
+
389
+ class _UnslothNashMDTrainer(OnlineDPOTrainer):
390
+ r""""""
391
+
392
+ _tag_names = ["trl", "nash-md"]
393
+
394
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    ref_model: Union[PreTrainedModel, nn.Module] = None,
    reward_model: Union[PreTrainedModel, nn.Module, None] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[NashMDConfig] = None,
    data_collator: Optional[Callable] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
    """Initialise the parent trainer, then install the Nash-MD specific state.

    Every argument is forwarded unchanged to the parent constructor;
    `processing_class` is additionally passed as `reward_processing_class`.
    Afterwards the mixture coefficient is read from `self.args` and the
    `self.stats` dictionary is replaced by the Nash-MD set of metric lists.
    """
    super().__init__(
        model=model,
        ref_model=ref_model,
        reward_model=reward_model,
        judge=judge,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        reward_processing_class=processing_class,  # for now, NashMDTrainer can't use any reward model
        peft_config=peft_config,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    self._mixture_coef = self.args.mixture_coef

    # Metric names tracked by this trainer. Compared with the parent's stats,
    # "non_score_reward", "rlhf_reward" and "scores_margin" are dropped and
    # "mixture_coef" is added.
    stat_names = (
        "loss/kl",
        "objective/entropy",
        "loss/score",
        "rewards/probabilities",
        "rewards/accuracies",
        "rewards/margins",
        "logps/chosen",
        "logps/rejected",
        "val/model_contain_eos_token",
        "val/ref_contain_eos_token",
        "beta",
        "mixture_coef",
    )
    # Reward-model scores are only tracked when a reward model is in use.
    if self.reward_model is not None:
        stat_names += ("rewards/chosen", "rewards/rejected")

    # One independent list per metric, in the order listed above.
    self.stats = {name: [] for name in stat_names}
453
+
454
@property
def mixture_coef(self):
    """Return the mixture coefficient that applies at the current epoch.

    If the configured value is a list, it is treated as a per-epoch schedule:
    entry ``i`` is used during epoch ``i`` and the last entry is used for every
    later epoch. Any other value is returned unchanged.

    Returns:
        The scheduled entry for the current epoch, or the configured value
        itself when it is not a list.
    """
    schedule = self._mixture_coef
    if not isinstance(schedule, list):
        return schedule

    # The trainer state reports progress as a fractional epoch (e.g. 0.25) and
    # may hold None before training starts. A float cannot index a list, so the
    # value is floored to the epoch number, with None treated as epoch 0.
    epoch = self.state.epoch
    epoch_index = int(epoch) if epoch is not None else 0

    if epoch_index < len(schedule):
        return schedule[epoch_index]
    return schedule[-1]
461
+
462
def _generate_completions(self, model, prompts):
    """Generate completions from the policy and from the policy/reference mixture.

    Args:
        model: The policy being trained (possibly wrapped by the accelerator).
        prompts: Mapping providing the prompt ``"input_ids"`` and
            ``"attention_mask"``.

    Returns:
        A ``(model_output, mixture_output)`` tuple holding the results of the
        two ``generate`` calls.
    """
    prompt_ids = prompts["input_ids"]
    prompt_mask = prompts["attention_mask"]

    # Completions sampled from the policy itself.
    with unwrap_model_for_generation(model, self.accelerator) as generation_policy:
        model_output = generation_policy.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            generation_config=self.generation_config,
        )

    # The mixture wrapper receives accelerator-unwrapped modules on both sides.
    policy = self.accelerator.unwrap_model(model)

    if self.ref_model is not None:
        # An explicit reference model was supplied: unwrap and use it.
        reference = self.accelerator.unwrap_model(self.ref_model)
    elif is_peft_available() and isinstance(policy, PeftModel):
        # No explicit reference: fall back to the base model under the PEFT adapters.
        reference = policy.get_base_model()
    else:
        # Neither a reference model nor a PEFT policy: the policy is its own reference.
        reference = policy

    # Mixture sampling is done without tracking gradients.
    with torch.no_grad():
        mixture_model = GeometricMixtureWrapper(
            model=policy,
            ref_model=reference,
            generation_config=self.generation_config,
            mixture_coef=self.mixture_coef,
            device=self.accelerator.device,
        )
        mixture_output = mixture_model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            generation_config=self.generation_config,
        )

    return model_output, mixture_output
509
+
510
def _process_completions(self, model_output, mixture_output, prompts):
    """Turn raw generation outputs into prompt+completion batches.

    For each output the prompt prefix is cut off, the remainder is passed
    through ``truncate_right`` with the tokenizer's EOS and pad ids, and the
    result is concatenated back onto the prompt ids / attention mask.

    Args:
        model_output: Token ids generated by the policy (prompt included).
        mixture_output: Token ids generated by the mixture model (prompt included).
        prompts: Mapping with ``"input_ids"``, ``"attention_mask"`` and ``"raw"``.

    Returns:
        ``(model_data, mixture_data)``, two dicts with the keys ``"input_ids"``,
        ``"attention_mask"`` and ``"raw"``.
    """
    prompt_ids = prompts["input_ids"]
    prompt_mask = prompts["attention_mask"]
    context_length = prompt_ids.shape[1]
    tokenizer = self.processing_class

    def assemble(generated):
        # Keep only the generated part, then let truncate_right produce the
        # completion ids together with their mask.
        completion_ids, completion_mask = truncate_right(
            generated[:, context_length:], tokenizer.eos_token_id, tokenizer.pad_token_id
        )
        return {
            "input_ids": torch.cat((prompt_ids, completion_ids), dim=1),
            "attention_mask": torch.cat((prompt_mask, completion_mask), dim=1),
            "raw": prompts["raw"],
        }

    model_data = assemble(model_output)
    mixture_data = assemble(mixture_output)
    return model_data, mixture_data
536
+
537
def _compute_rewards(self, model_data, mixture_data, context_length):
    """Score the policy and mixture completions with the reward model.

    Args:
        model_data: Batch dict for the policy completions; ``"input_ids"`` holds
            prompt and completion tokens.
        mixture_data: Batch dict for the mixture completions, same layout.
        context_length: Number of prompt tokens at the start of every row.

    Returns:
        ``(model_scores, mixture_scores)``. When ``args.missing_eos_penalty`` is
        set, that amount has been subtracted from every row whose completion
        contains no EOS token.
    """
    tokenizer = self.processing_class

    with torch.no_grad():
        _, model_scores, _ = get_reward(
            self.reward_model, model_data["input_ids"], tokenizer.pad_token_id, context_length
        )
        _, mixture_scores, _ = get_reward(
            self.reward_model, mixture_data["input_ids"], tokenizer.pad_token_id, context_length
        )

    # Apply EOS penalty if needed
    penalty = self.args.missing_eos_penalty
    if penalty is not None:
        eos_token_id = tokenizer.eos_token_id
        # Only the generated part is inspected. An EOS id inside the prompt
        # (for instance when EOS doubles as the padding token) must not count
        # as the model having finished its answer. This matches the EOS
        # statistics in `_log_statistics`, which use the same slice.
        model_contain_eos = torch.any(model_data["input_ids"][:, context_length:] == eos_token_id, dim=-1)
        mixture_contain_eos = torch.any(mixture_data["input_ids"][:, context_length:] == eos_token_id, dim=-1)
        model_scores[~model_contain_eos] -= penalty
        mixture_scores[~mixture_contain_eos] -= penalty

    return model_scores, mixture_scores
554
+
555
def _compute_judge(self, model_data, mixture_data, context_length):
    """Ask the pairwise judge to compare policy completions with mixture completions.

    Completions are decoded to stripped text. If the first prompt is
    conversational, prompts and completions are rendered through
    ``SIMPLE_CHAT_TEMPLATE`` before being handed to the judge.

    Args:
        model_data: Batch dict for the policy completions (``"input_ids"``, ``"raw"``).
        mixture_data: Batch dict for the mixture completions.
        context_length: Number of prompt tokens at the start of every row.

    Returns:
        A tensor built from the judge's scores, placed on the device of
        ``model_data["input_ids"]``.
    """

    def decode_completions(data):
        # Decode only the generated part and trim surrounding whitespace.
        decoded = self.processing_class.batch_decode(
            data["input_ids"][:, context_length:], skip_special_tokens=True
        )
        return [text.strip() for text in decoded]

    prompts = model_data["raw"]
    model_texts = decode_completions(model_data)
    mixture_texts = decode_completions(mixture_data)

    if is_conversational({"prompt": prompts[0]}):
        template = jinja2.Environment().from_string(SIMPLE_CHAT_TEMPLATE)

        def render_reply(text):
            # Wrap a completion as a single assistant message and render it.
            return template.render(messages=[{"role": "assistant", "content": text}])

        prompts = [template.render(messages=message) for message in prompts]
        model_texts = [render_reply(text) for text in model_texts]
        mixture_texts = [render_reply(text) for text in mixture_texts]

    completion_pairs = list(zip(model_texts, mixture_texts))
    probability = self.judge.judge(
        prompts,
        completion_pairs,
        return_scores=True,
    )
    return torch.tensor(probability, device=model_data["input_ids"].device)
588
+
589
def _compute_logprobs(self, model, model_data, context_length):
    """Compute per-token log-probabilities of the policy completions.

    The completions in ``model_data`` are scored twice: once by ``model`` (with
    gradients) and once, under ``torch.no_grad()``, by the reference. The
    reference is ``self.ref_model`` when set, otherwise ``model`` inside its
    ``disable_adapter()`` context.

    Args:
        model: The policy being trained.
        model_data: Batch dict with ``"input_ids"`` and ``"attention_mask"``.
        context_length: Number of prompt tokens at the start of every row.

    Returns:
        ``(model_logprobs, ref_logprobs)`` with positions whose attention mask
        is 0 set to ``0.0``.
    """
    completion_ids = model_data["input_ids"][:, context_length:]

    def completion_logprobs(network):
        outputs = network(model_data["input_ids"], attention_mask=model_data["attention_mask"])
        # Logits at position t predict token t + 1, hence the shift by one.
        shifted_logits = outputs.logits[:, context_length - 1 : -1]
        return selective_log_softmax(shifted_logits, completion_ids)

    # Log-probabilities under the policy (gradients flow through these).
    policy_logprobs = completion_logprobs(model)

    # Log-probabilities of the same completions under the reference.
    with torch.no_grad():
        if self.ref_model is not None:
            reference_logprobs = completion_logprobs(self.ref_model)
        else:
            with model.disable_adapter():
                reference_logprobs = completion_logprobs(model)

    # Zero out the entries that correspond to padding.
    is_padding = model_data["attention_mask"][:, context_length:] == 0
    policy_logprobs = policy_logprobs.masked_fill(is_padding, 0.0)
    reference_logprobs = reference_logprobs.masked_fill(is_padding, 0.0)

    return (policy_logprobs, reference_logprobs)
613
+
614
def _compute_losses(
    self,
    model_logprobs_model_data,
    ref_logprobs_model_data,
    probability,
):
    """Build the Nash-MD training loss from token log-probabilities.

    Args:
        model_logprobs_model_data: Per-token log-probs of the completions under the policy.
        ref_logprobs_model_data: Per-token log-probs of the same completions under the reference.
        probability: Per-sequence preference probability of the policy completion.

    Returns:
        ``(loss, score, kl_div_log)``: the batch-mean loss, the per-sequence
        REINFORCE score and the per-sequence summed log-ratio (no gradient).
    """
    # REINFORCE score; subtracting 0.5 acts as a control variate.
    sequence_logprob = model_logprobs_model_data.sum(1)
    score = (probability - 0.5) * sequence_logprob

    # The log-ratio is a constant weight: it is computed without gradients so
    # that only the log-prob factor below is differentiated.
    with torch.no_grad():
        log_ratio = model_logprobs_model_data - ref_logprobs_model_data
        kl_div_log = log_ratio.sum(1)
    kl_penalty = (log_ratio * model_logprobs_model_data).sum(1)

    per_sequence_loss = self.beta * kl_penalty - score

    return per_sequence_loss.mean(), score, kl_div_log
633
+
634
def _log_statistics(
    self,
    model_data,
    mixture_data,
    model_logprobs_model_data,
    ref_logprobs_model_data,
    probability,
    score,
    kl_div,
    context_length,
    model_scores=None,
    mixture_scores=None,
):
    """Append the metrics of the current step to ``self.stats``.

    Tensor metrics are gathered with ``accelerator.gather_for_metrics`` and
    reduced to their mean; ``beta`` and ``mixture_coef`` are appended as-is.
    Reward-model scores are recorded only when ``self.reward_model`` is set.
    """

    def record(name, tensor):
        # Gather across processes, then store the scalar mean.
        self.stats[name].append(self.accelerator.gather_for_metrics(tensor).mean().item())

    # Score and KL divergence
    record("loss/score", score)
    record("loss/kl", kl_div)

    # Sequence-level log-probabilities under the policy and the reference
    policy_sequence_logprob = model_logprobs_model_data.sum(1)
    reference_sequence_logprob = ref_logprobs_model_data.sum(1)
    record("logps/chosen", policy_sequence_logprob)
    record("logps/rejected", reference_sequence_logprob)

    # Reward-model scores
    if self.reward_model is not None:
        record("rewards/chosen", model_scores)
        record("rewards/rejected", mixture_scores)

    # Preference probabilities
    record("rewards/probabilities", probability)

    # Entropy estimate: negated summed log-probabilities of the policy completions
    record("objective/entropy", -model_logprobs_model_data.sum(1))

    # Margin between policy and reference, and how often it is positive
    margin = policy_sequence_logprob - reference_sequence_logprob
    record("rewards/margins", margin)
    record("rewards/accuracies", (margin > 0).float())

    # Fraction of completions that contain an EOS token
    eos_token_id = self.processing_class.eos_token_id
    model_has_eos = (model_data["input_ids"][:, context_length:] == eos_token_id).any(dim=1)
    mixture_has_eos = (mixture_data["input_ids"][:, context_length:] == eos_token_id).any(dim=1)
    record("val/model_contain_eos_token", model_has_eos.float())
    record("val/ref_contain_eos_token", mixture_has_eos.float())

    # Current hyper-parameter values
    self.stats["beta"].append(self.beta)
    self.stats["mixture_coef"].append(self.mixture_coef)
692
+
693
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one Nash-MD optimisation step on a batch of prompts.

    The step tokenizes the prompts, samples completions from the policy and
    from the mixture model, obtains a preference probability (reward model if
    available, otherwise the judge), computes the loss, logs statistics and
    back-propagates.

    Args:
        model: The policy being trained.
        inputs: Column-oriented batch; must contain a ``"prompt"`` entry.
        num_items_in_batch: Accepted for interface compatibility; not used here.

    Returns:
        The detached loss divided by ``args.gradient_accumulation_steps``.
    """
    model.train()

    # Apply chat template and tokenize the input
    batch_size = len(next(iter(inputs.values())))
    raw_prompts = inputs["prompt"]
    rows = [{key: column[i] for key, column in inputs.items()} for i in range(batch_size)]
    rows = [maybe_apply_chat_template(row, self.processing_class) for row in rows]
    rows = [self.tokenize_row(row, self.model.config.is_encoder_decoder, self.processing_class) for row in rows]
    batch = self.data_collator(rows)

    # Only the prompt_* entries of the prepared batch are needed from here on.
    batch = self._prepare_inputs(batch)
    context_length = batch["prompt_input_ids"].shape[1]
    prompts = {
        "input_ids": batch["prompt_input_ids"],
        "attention_mask": batch["prompt_attention_mask"],
        "raw": raw_prompts,
    }
    del rows, batch

    # Sample completions from both the model and the mixture model
    model_output, mixture_output = self._generate_completions(model, prompts)

    # Build prompt+completion batches for both sets of samples
    model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)

    # Preference probability of the model data vs the mixture data
    if self.reward_model is None:
        model_scores, mixture_scores = None, None
        probability = self._compute_judge(model_data, mixture_data, context_length)
    else:
        model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
        probability = F.sigmoid(model_scores - mixture_scores)

    # Compute logprobs
    model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)

    # Compute loss
    loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)

    # Log everything
    self._log_statistics(
        model_data,
        mixture_data,
        model_logprobs_model_data.detach(),
        ref_logprobs_model_data,
        probability,
        score.detach(),
        kl_div.detach(),
        context_length,
        model_scores,
        mixture_scores,
    )

    # Periodically release cached memory when configured to do so.
    empty_cache_every = self.args.torch_empty_cache_steps
    if empty_cache_every is not None and self.state.global_step % empty_cache_every == 0:
        empty_cache()

    backward_kwargs = {}
    # For LOMO optimizers you need to explicitly use the learning rate
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        backward_kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    if self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        self.accelerator.backward(loss, **backward_kwargs)

    return loss.detach() / self.args.gradient_accumulation_steps
772
+
773
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes for which
    `is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # normalize `tags` to a mutable set
    if tags is None:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    # NOTE: this is not a raw string, so every BibTeX backslash must be written as `\\`.
    # A bare `\'` would be read by Python as an escaped quote and the accent command
    # in "R{\'{e}}mi" would lose its backslash.
    citation = textwrap.dedent("""\
    @inproceedings{munos2024nash,
        title = {{Nash Learning from Human Feedback}},
        author = {R{\\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
        year = 2024,
        booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
        publisher = {OpenReview.net},
        url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="Nash-MD",
        trainer_citation=citation,
        paper_title="Nash Learning from Human Feedback",
        paper_id="2312.00886",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
836
+ class UnslothNashMDTrainer(_UnslothNashMDTrainer):
837
+ """
838
+
839
+ Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
840
+
841
+ Args:
842
+ model (`transformers.PreTrainedModel`):
843
+ The model to train, preferably an `AutoModelForCausalLM`.
844
+ ref_model (`PreTrainedModelWrapper`):
845
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
846
+ and loss. If no reference model is provided, the trainer will create a reference model with the same
847
+ architecture as the model to be optimized.
848
+ reward_model (`transformers.PreTrainedModel`):
849
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
850
+ judge (`BasePairwiseJudge`):
851
+ The judge to use for pairwise comparison of model completions.
852
+ args (`NashMDConfig`):
853
+ The NashMD config arguments to use for training.
854
+ data_collator (`transformers.DataCollator`):
855
+ The data collator to use for training. If None is specified, the default data collator
856
+ (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
857
+ sequences in the batch, given a dataset of paired sequences.
858
+ train_dataset (`datasets.Dataset`):
859
+ The dataset to use for training.
860
+ eval_dataset (`datasets.Dataset`):
861
+ The dataset to use for evaluation.
862
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
863
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
864
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
865
+ reuse the fine-tuned model.
866
+ peft_config (`dict`):
867
+ The peft config to use for training.
868
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
869
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
870
+ metric values.
871
+ callbacks (`list[transformers.TrainerCallback]`):
872
+ The callbacks to use for training.
873
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
874
+ The optimizer and scheduler to use for training.
875
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
876
+ The function to use to preprocess the logits before computing the metrics.
877
+
878
+ """
879
+ def __init__(
880
+ self,
881
+ model = None,
882
+ ref_model = None,
883
+ reward_model = None,
884
+ judge = None,
885
+ args = None,
886
+ data_collator = None,
887
+ train_dataset = None,
888
+ eval_dataset = None,
889
+ processing_class = None,
890
+ peft_config = None,
891
+ compute_metrics = None,
892
+ callbacks = None,
893
+ preprocess_logits_for_metrics = None,
894
+ **kwargs
895
+ ):
896
+ if args is None: args = UnslothNashMDConfig()
897
+ use_bf16 = getattr(args, 'bf16', False)
898
+ if type(use_bf16) is not bool: use_bf16 = False
899
+ use_fp16 = getattr(args, 'fp16', False)
900
+ if type(use_fp16) is not bool: use_fp16 = False
901
+ force_float32 = False
902
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
903
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
904
+ force_float32 = True
905
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
906
+ dtype = getattr(model.config, 'torch_dtype', None)
907
+ if dtype is None: dtype = model.get_input_embeddings().dtype
908
+ from unsloth_zoo.utils import _get_dtype
909
+ dtype = _get_dtype(dtype)
910
+ float16 = dtype == torch.float16
911
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
912
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
913
+ if force_float32:
914
+ args.fp16 = False
915
+ args.bf16 = False
916
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
917
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
918
+ args.fp16 = float16
919
+ args.bf16 = not float16
920
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
921
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
922
+ args.eval_strategy = 'steps'
923
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
924
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
925
+ if ga_steps is not None and ga_steps > 1:
926
+ from transformers import __version__ as transformers_version
927
+ if Version(transformers_version) <= Version('4.45.2'):
928
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
929
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
930
+ if getattr(args, 'eval_strategy', 'no') != 'no':
931
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
932
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
933
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
934
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
935
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
936
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
937
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
938
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
939
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
940
+ if force_float32:
941
+ args.bf16_full_eval = False
942
+ args.fp16_full_eval = False
943
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
944
+ args.bf16_full_eval = True
945
+ args.fp16_full_eval = False
946
+ elif not bf16_full_eval and not fp16_full_eval:
947
+ args.bf16_full_eval = args.bf16
948
+ args.fp16_full_eval = args.fp16
949
+ _output_logits = False
950
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
951
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
952
+ if _output_logits:
953
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
954
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
955
+ pass
956
+ else:
957
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
958
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
959
+ if args_max_seq_length is None and model_max_seq_length is not None:
960
+ max_seq_length = model.max_seq_length
961
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
962
+ if model is not None and hasattr(model, 'for_training'):
963
+ model.for_training()
964
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
965
+ if 'processing_class' in locals():
966
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
967
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
968
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
969
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
970
+ if not isinstance(data_collator, UnslothVisionDataCollator):
971
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
972
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
973
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
974
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
975
+ else:
976
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
977
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
978
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
979
+ if not isinstance(data_collator, UnslothVisionDataCollator):
980
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
981
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
982
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
983
+ else:
984
+ data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
985
+ other_metrics = []
986
+
987
+ from unsloth_zoo.logging_utils import PatchRLStatistics
988
+ PatchRLStatistics('nash_md_trainer', other_metrics)
989
+
990
+ super().__init__(
991
+ model = model,
992
+ ref_model = ref_model,
993
+ reward_model = reward_model,
994
+ judge = judge,
995
+ args = args,
996
+ data_collator = data_collator,
997
+ train_dataset = train_dataset,
998
+ eval_dataset = eval_dataset,
999
+ processing_class = processing_class,
1000
+ peft_config = peft_config,
1001
+ compute_metrics = compute_metrics,
1002
+ callbacks = callbacks,
1003
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
1004
+ if hasattr(self, 'neftune_hook_handle'):
1005
+ self.neftune_hook_handle.remove()
1006
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1007
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1008
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1009
+ pass
1010
+ if hasattr(self, 'accelerator'):
1011
+ scaler = self.accelerator.scaler
1012
+ current_model = model
1013
+ while hasattr(current_model, 'model'):
1014
+ current_model.accelerator_scaler = scaler
1015
+ current_model = current_model.model
1016
+ current_model.accelerator_scaler = scaler
1017
+ pass
1018
+
1019
+ pass
unsloth_compiled_cache/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Keyword options handed to `torch.compile(..., options = torch_compile_options)`
# by the compiled helper(s) defined in this module. The keys are passed through
# unchanged; see the torch.compile / inductor configuration docs for their meaning.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """Return the log-softmax of `logits` evaluated only at the entries picked by `index`.

    All leading dimensions of `logits` are flattened into rows and `index` is
    flattened to one entry per row. Both are split with `torch.chunk(chunks = 4)`
    and each logits chunk is upcast to float32 before its values are computed as
    `selected_logit - logsumexp(row)`.

    Args:
        logits: Tensor whose last dimension is the axis the softmax runs over.
            The final reshape uses `logits.shape[0]` and `logits.shape[1]`, so a
            3-D input is expected (presumably (batch, seq, vocab) -- confirm
            against callers).
        index: Integer tensor holding, for every flattened row of `logits`, the
            position along the last axis to read.

    Returns:
        A float32 tensor of shape `(logits.shape[0], logits.shape[1])`.
    """
    row_chunks = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
    target_chunks = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)

    chunk_logps = []
    for rows, targets in zip(row_chunks, target_chunks):
        # Work in float32 regardless of the incoming dtype.
        rows = rows.to(torch.float32)
        log_normalizer = torch.logsumexp(rows, dim = -1)
        picked = torch.gather(rows, dim = -1, index = targets.unsqueeze(-1)).squeeze(-1)
        # log_softmax(x)[i] == x[i] - logsumexp(x)
        chunk_logps.append(picked - log_normalizer)

    flat_logps = torch.concat(chunk_logps)
    return flat_logps.reshape((logits.shape[0], logits.shape[1]))
51
@dataclass
class UnslothORPOConfig(ORPOConfig):
    """
    Configuration for the [`ORPOTrainer`], extended with a few Unsloth-specific fields.

    Only the ORPO-specific parameters are described here; everything else is a regular
    [`~transformers.TrainingArguments`] argument and is forwarded unchanged to the parent
    class. Note that the default values used by this class may differ from those of
    [`~transformers.TrainingArguments`].

    With [`~transformers.HfArgumentParser`] this class can be turned into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that
    can be given on the command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. Required when the default data
            collator is used.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. Required when the default data collator is used.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion. Required when the default data collator is used and the model is
            an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the relative ratio loss weight in the ORPO loss. In the
            [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the
            [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. Required when the default data collator is used.
        padding_value (`int` or `None`, *optional*, defaults to `None`):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or
            `"keep_start"`. Required when the default data collator is used.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
            When the model is created through the `model_init` callable instead of the `model` argument, this
            states whether the model returned by the callable is an encoder-decoder model.
        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model
            from a string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, the constructor replaces it
            with `min(cpu_count()*2, 2)`.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to the parent class.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. Stored on the instance; not forwarded to the parent class.
        max_seq_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum sequence length to truncate to. Stored on the instance; not forwarded to the parent class.
    """

    # Extra dataclass fields added on top of ORPOConfig.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        hub_revision=None,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=True,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        liger_kernel_config=None,
        eval_use_gather_object=False,
        average_tokens_across_devices=True,
        max_length=1024,
        max_prompt_length=512,
        max_completion_length=None,
        beta=0.1,
        disable_dropout=True,
        label_pad_token_id=-100,
        padding_value=None,
        truncation_mode='keep_end',
        generate_during_eval=False,
        is_encoder_decoder=None,
        model_init_kwargs=None,
        dataset_num_proc=None,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        max_seq_length=None,
        **kwargs,
    ):
        """Validate a few arguments, fill in defaults, then defer to `ORPOConfig.__init__`.

        Raises:
            FloatingPointError: if `learning_rate` is below 1e-7.
            OverflowError: if `learning_rate` is above 1.
        """
        # Reject learning rates outside [1e-7, 1] before anything else happens.
        if learning_rate < 1e-7:
            raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1:
            raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        # No output directory and untouched save settings: pick a directory name
        # and switch checkpoint saving off.
        save_settings_untouched = save_strategy == 'steps' and save_steps == 500
        if output_dir is None and save_settings_untouched:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            # NOTE(review): this evaluates to 2 whenever cpu_count() >= 1;
            # confirm whether `min` (a cap of 2) is really what is intended.
            dataset_num_proc = min(cpu_count()*2, 2)

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            hub_revision=hub_revision,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            liger_kernel_config=liger_kernel_config,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_completion_length=max_completion_length,
            beta=beta,
            disable_dropout=disable_dropout,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            generate_during_eval=generate_during_eval,
            is_encoder_decoder=is_encoder_decoder,
            model_init_kwargs=model_init_kwargs,
            dataset_num_proc=dataset_num_proc,
            **kwargs,
        )

        # The Unsloth-only settings are kept on the instance rather than being
        # passed to the parent constructor.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
412
+
413
+ class _UnslothORPOTrainer(Trainer):
414
+ r""""""
415
+
416
+ _tag_names = ["trl", "orpo"]
417
+
418
+ def __init__(
419
+ self,
420
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
421
+ args: Optional[ORPOConfig] = None,
422
+ data_collator: Optional[DataCollator] = None,
423
+ train_dataset: Optional[Dataset] = None,
424
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
425
+ processing_class: Optional[
426
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
427
+ ] = None,
428
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
429
+ callbacks: Optional[list[TrainerCallback]] = None,
430
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
431
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
432
+ peft_config: Optional[dict] = None,
433
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
434
+ ):
435
+ if args.model_init_kwargs is None:
436
+ model_init_kwargs = {}
437
+ elif not isinstance(model, str):
438
+ raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
439
+ else:
440
+ model_init_kwargs = args.model_init_kwargs
441
+ torch_dtype = model_init_kwargs.get("torch_dtype")
442
+ if torch_dtype is not None:
443
+ # Convert to `torch.dtype` if an str is passed
444
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
445
+ torch_dtype = getattr(torch, torch_dtype)
446
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
447
+ raise ValueError(
448
+ f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
449
+ )
450
+ model_init_kwargs["torch_dtype"] = torch_dtype
451
+
452
+ if isinstance(model, str):
453
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
454
+
455
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
456
+ # has been called in order to properly call autocast if needed.
457
+ self._peft_has_been_casted_to_bf16 = False
458
+
459
+ if not is_peft_available() and peft_config is not None:
460
+ raise ValueError(
461
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
462
+ )
463
+ elif is_peft_available() and peft_config is not None:
464
+ # if model is a peft model and we have a peft_config, we merge and unload it first
465
+ if isinstance(model, PeftModel):
466
+ model = model.merge_and_unload()
467
+
468
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
469
+ _support_gc_kwargs = hasattr(
470
+ args, "gradient_checkpointing_kwargs"
471
+ ) and "gradient_checkpointing_kwargs" in list(
472
+ inspect.signature(prepare_model_for_kbit_training).parameters
473
+ )
474
+
475
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
476
+
477
+ if _support_gc_kwargs:
478
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
479
+
480
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
481
+ elif args.gradient_checkpointing:
482
+ # For backward compatibility with older versions of transformers
483
+ if hasattr(model, "enable_input_require_grads"):
484
+ model.enable_input_require_grads()
485
+ else:
486
+
487
+ def make_inputs_require_grad(module, input, output):
488
+ output.requires_grad_(True)
489
+
490
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
491
+
492
+ # get peft model with the given config
493
+ model = model
494
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
495
+ peft_module_casting_to_bf16(model)
496
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
497
+ self._peft_has_been_casted_to_bf16 = True
498
+
499
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
500
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
501
+ # fail or completely fail.
502
+ elif args.gradient_checkpointing:
503
+ # For backward compatibility with older versions of transformers
504
+ if hasattr(model, "enable_input_require_grads"):
505
+ model.enable_input_require_grads()
506
+ else:
507
+
508
+ def make_inputs_require_grad(module, input, output):
509
+ output.requires_grad_(True)
510
+
511
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
512
+
513
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
514
+ raise ValueError(
515
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
516
+ " Please install `wandb` or `comet-ml` to resolve."
517
+ )
518
+
519
+ if model is not None:
520
+ self.is_encoder_decoder = model.config.is_encoder_decoder
521
+ elif args.is_encoder_decoder is None:
522
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
523
+ else:
524
+ self.is_encoder_decoder = args.is_encoder_decoder
525
+
526
+ if self.is_encoder_decoder:
527
+ self.decoder_start_token_id = model.config.decoder_start_token_id
528
+ self.pad_token_id = model.config.pad_token_id
529
+
530
+ if processing_class is None:
531
+ raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
532
+ if args.max_length is None:
533
+ warnings.warn(
534
+ "`max_length` is not set in the ORPOConfig's init"
535
+ " it will default to `512` by default, but you should do it yourself in the future.",
536
+ UserWarning,
537
+ )
538
+ max_length = 512
539
+ else:
540
+ max_length = args.max_length
541
+ if args.max_prompt_length is None:
542
+ warnings.warn(
543
+ "`max_prompt_length` is not set in the ORPOConfig's init"
544
+ " it will default to `128` by default, but you should do it yourself in the future.",
545
+ UserWarning,
546
+ )
547
+ max_prompt_length = 128
548
+ else:
549
+ max_prompt_length = args.max_prompt_length
550
+
551
+ if args.max_completion_length is None and self.is_encoder_decoder:
552
+ warnings.warn(
553
+ "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
554
+ " it will default to `128` by default, but you should do it yourself in the future.",
555
+ UserWarning,
556
+ )
557
+ self.max_completion_length = 128
558
+ else:
559
+ self.max_completion_length = args.max_completion_length
560
+
561
+ if data_collator is None:
562
+ data_collator = DPODataCollatorWithPadding(
563
+ pad_token_id=processing_class.pad_token_id,
564
+ label_pad_token_id=args.label_pad_token_id,
565
+ is_encoder_decoder=self.is_encoder_decoder,
566
+ )
567
+
568
+ if args.remove_unused_columns:
569
+ args.remove_unused_columns = False
570
+ # warn users
571
+ warnings.warn(
572
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
573
+ " we have set it for you, but you should do it yourself in the future.",
574
+ UserWarning,
575
+ )
576
+
577
+ self.use_dpo_data_collator = True
578
+ else:
579
+ self.use_dpo_data_collator = False
580
+
581
+ # Disable dropout in the model and reference model
582
+ if args.disable_dropout:
583
+ disable_dropout_in_model(model)
584
+
585
+ self.max_length = max_length
586
+ self.generate_during_eval = args.generate_during_eval
587
+ self.label_pad_token_id = args.label_pad_token_id
588
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
589
+ self.max_prompt_length = max_prompt_length
590
+ self.truncation_mode = args.truncation_mode
591
+ self.processing_class = processing_class
592
+
593
+ self.beta = args.beta
594
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
595
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
596
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
597
+ warnings.warn(
598
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
599
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
600
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
601
+ "loss.",
602
+ UserWarning,
603
+ )
604
+
605
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
606
+
607
+ # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
608
+ # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
609
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
610
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
611
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
612
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
613
+ # that the warning has already been issued.
614
+ model.warnings_issued["estimate_tokens"] = True
615
+
616
+ # Compute that only on the main process for faster data processing.
617
+ # see: https://github.com/huggingface/trl/pull/1255
618
+ with PartialState().main_process_first():
619
+ # Extract the prompt if needed, and apply the chat template if needed
620
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
621
+ train_dataset = train_dataset.map(
622
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
623
+ )
624
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
625
+ if eval_dataset is not None:
626
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
627
+ eval_dataset = eval_dataset.map(
628
+ maybe_apply_chat_template,
629
+ fn_kwargs={"tokenizer": processing_class},
630
+ num_proc=args.dataset_num_proc,
631
+ )
632
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
633
+
634
+ super().__init__(
635
+ model=model,
636
+ args=args,
637
+ data_collator=data_collator,
638
+ train_dataset=train_dataset,
639
+ eval_dataset=eval_dataset,
640
+ processing_class=processing_class,
641
+ model_init=model_init,
642
+ compute_metrics=compute_metrics,
643
+ callbacks=callbacks,
644
+ optimizers=optimizers,
645
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
646
+ )
647
+
648
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
649
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
650
+ # self.model_accepts_loss_kwargs to False to enable scaling.
651
+ self.model_accepts_loss_kwargs = False
652
+
653
+ # Add tags for models that have been loaded with the correct transformers version
654
+ if hasattr(self.model, "add_model_tags"):
655
+ self.model.add_model_tags(self._tag_names)
656
+
657
+ if not hasattr(self, "accelerator"):
658
+ raise AttributeError(
659
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
660
+ )
661
+
662
def build_tokenized_answer(self, prompt, answer):
    """
    Tokenize `prompt + answer` once and split the result into a prompt part and an answer part.

    The Llama tokenizer does *not* satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure
    `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. Reference:
    https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Args:
        prompt: Prompt text.
        answer: Answer text; it is concatenated directly after `prompt` before tokenizing.

    Returns:
        A dict with the keys `prompt_input_ids`, `prompt_attention_mask`, `input_ids` and `attention_mask`.
        The last two hold the answer part. All four values are slices of the tokenization of
        `prompt + answer`.

    Raises:
        ValueError: If the prompt tokenized on its own is longer than the tokenized `prompt + answer`, or if the
            prompt ids and prompt attention mask end up with different lengths.
    """
    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]

    # `enc(a) + enc(a + b)[len(enc(a)):]` must be exactly as long as `enc(a + b)`. Only the lengths matter
    # here, so no concatenated array needs to be materialised.
    answer_length = len(full_input_ids[len(prompt_input_ids) :])
    if len(prompt_input_ids) + answer_length != len(full_input_ids):
        raise ValueError("Prompt input ids and answer input ids should have the same length.")

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If the prompt tokenized alone differs from the same-length prefix of prompt+answer, the last prompt
    # token has changed due to merging, so that token is handed over to the answer part.
    if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
    prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=full_input_ids[response_token_ids_start_idx:],
        attention_mask=full_attention_mask[response_token_ids_start_idx:],
    )
710
+
711
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize a single row from an ORPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
    chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
    we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
    of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.

    Args:
        feature: Mapping providing the `"prompt"`, `"chosen"` and `"rejected"` entries of the row.
        model: Optional model. It is only consulted in the encoder-decoder branch, where its
            `prepare_decoder_input_ids_from_labels` method (if present) builds the decoder input ids.

    Returns:
        A dict of tokenized fields. For decoder-only models the keys are prefixed with `chosen_`, `rejected_`
        or carry the `prompt_` prefix coming from the prompt tokens; for encoder-decoder models the keys are
        `chosen_labels`, `rejected_labels`, `prompt_input_ids`, `prompt_attention_mask` and, when available,
        the two `*_decoder_input_ids` entries.

    Raises:
        ValueError: In the decoder-only branch, if prompt/chosen/rejected is not a `str`, if the two prompt
            tokenizations differ by more than one token or one position, or if `truncation_mode` is unknown.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]

    if not self.is_encoder_decoder:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens,
        # so the shared prompt is cut to the shorter of the two prompt tokenizations.
        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        for k, v in prompt_tokens.items():
            prompt_tokens[k] = v[:prompt_len_input_ids]

        # Make sure the two prompts differ by at most one token
        # and that their lengths differ by at most 1
        num_diff_tokens = sum(
            a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels: full sequence = prompt part + response part, with the prompt positions masked out
        chosen_sequence_tokens = {
            k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        rejected_sequence_tokens = {
            k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(chosen_tokens["prompt_input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(rejected_tokens["prompt_input_ids"])

        for k, toks in {
            "chosen_": chosen_sequence_tokens,
            "rejected_": rejected_sequence_tokens,
            "": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}{type_key}"] = tokens

    else:
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )

    if is_torch_xla_available():
        # Pad the sequences to global max_length to avoid TorchXLA recompilation
        # NOTE(review): a key matching none of the branches below would reuse the previous `pad_value`
        # (or fail on the first iteration) - confirm every batch key is covered.
        for k in batch:
            if "labels" in k or self.is_encoder_decoder:
                pad_value = self.label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = self.padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
    return batch
865
+
866
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch dimension.

    Every tensor entry whose key starts with `chosen` is padded to a common length and stored under the same
    key with `chosen` replaced by `concatenated`; the matching `rejected` tensor is padded the same way,
    appended below it (dim 0) and the result is moved to `device`.

    Args:
        batch:
            A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids' (or
            'chosen_labels' and 'rejected_labels' for encoder-decoder models), which are tensors of shape
            (batch_size, sequence_length).
        is_encoder_decoder:
            Whether the model is an encoder-decoder model.
        label_pad_token_id:
            The label pad token id.
        padding_value:
            The padding value to use for the concatenated inputs_ids.
        device:
            The device for the concatenated inputs.

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids' (plus the other
        `concatenated_*` entries derived from the batch).
    """
    length_key = "labels" if is_encoder_decoder else "input_ids"
    max_length = max(batch[f"chosen_{length_key}"].shape[1], batch[f"rejected_{length_key}"].shape[1])

    concatenated_batch = {}
    # Chosen entries are written first; rejected entries are then appended to them.
    for side in ("chosen", "rejected"):
        for key, value in batch.items():
            if not (key.startswith(side) and isinstance(value, torch.Tensor)):
                continue
            if "labels" in key or is_encoder_decoder:
                fill = label_pad_token_id
            elif key.endswith("_input_ids"):
                fill = padding_value
            elif key.endswith("_attention_mask"):
                fill = 0
            target_key = key.replace(side, "concatenated")
            padded = pad_to_length(value, max_length, pad_value=fill)
            if side == "chosen":
                concatenated_batch[target_key] = padded
            else:
                stacked = torch.cat((concatenated_batch[target_key], padded), dim=0)
                concatenated_batch[target_key] = stacked.to(device=device)

    if is_encoder_decoder:
        # The encoder sees the same prompt for both halves of the stacked batch.
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch
933
+
934
def odds_ratio_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute ORPO's odds ratio (OR) term for a batch of policy log probabilities.

    Args:
        policy_chosen_logps:
            Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps:
            Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of five tensors, in this order:
        `losses` (`beta * logsigmoid(log_odds)` per example), `chosen_rewards` and `rejected_rewards`
        (`beta` times the detached chosen/rejected log probabilities, moved to the accelerator device),
        the batch mean of `logsigmoid(log_odds)`, and the batch mean of `log_odds`. The last two are
        returned for logging purposes.
    """
    # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
    logp_gap = policy_chosen_logps - policy_rejected_logps
    log1m_gap = torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
    log_odds = logp_gap - log1m_gap

    ratio = F.logsigmoid(log_odds)
    losses = self.beta * ratio

    target_device = self.accelerator.device
    chosen_rewards = self.beta * policy_chosen_logps.to(target_device).detach()
    rejected_rewards = self.beta * policy_rejected_logps.to(target_device).detach()

    return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
965
+
966
+ @staticmethod
967
+ def get_batch_logps(
968
+ logits: torch.FloatTensor,
969
+ labels: torch.LongTensor,
970
+ average_log_prob: bool = False,
971
+ label_pad_token_id: int = -100,
972
+ is_encoder_decoder: bool = False,
973
+ ) -> torch.FloatTensor:
974
+ """Compute the log probabilities of the given labels under the given logits.
975
+
976
+ Args:
977
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
978
+ labels:
979
+ Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
980
+ ignored. Shape: (batch_size, sequence_length)
981
+ average_log_prob:
982
+ If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
983
+ log probabilities of the (non-masked) tokens.
984
+ label_pad_token_id: The label pad token id.
985
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
986
+
987
+ Returns:
988
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
989
+ given logits.
990
+ """
991
+ if logits.shape[:-1] != labels.shape:
992
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
993
+
994
+ if not is_encoder_decoder:
995
+ labels = labels[:, 1:].clone()
996
+ logits = logits[:, :-1, :]
997
+ loss_mask = labels != label_pad_token_id
998
+
999
+ # dummy token; we'll ignore the losses on these tokens later
1000
+ labels = torch.where(labels == label_pad_token_id, 0, labels)
1001
+
1002
+ per_token_logps = selective_log_softmax(logits, labels)
1003
+
1004
+ if average_log_prob:
1005
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1006
+ else:
1007
+ return (per_token_logps * loss_mask).sum(-1)
1008
+
1009
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, ...]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Args:
        model: The model to run; it is called with the concatenated input ids, the attention mask and
            `use_cache=False`, and must return an object exposing `.logits` (and `.aux_loss` when the
            auxiliary loss is enabled).
        batch: Collated batch. `batch["chosen_labels"]` determines how many rows of the concatenated batch
            belong to the chosen responses.

    Returns:
        `(chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)`, followed by
        `outputs.aux_loss` as a sixth element when `self.aux_loss_enabled` is set.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
        }
        if self.is_encoder_decoder
        else {}
    )

    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Padding positions were filled with `label_pad_token_id` below, so that is the index that has to
        # be ignored (the loss default of -100 is only right when the pad id is left at its default).
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        # Flatten the tokens
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    if self.is_encoder_decoder:
        labels = concatenated_batch["concatenated_labels"].clone()
    else:
        labels = concatenated_batch["concatenated_input_ids"].clone()
        attention_mask = concatenated_batch["concatenated_attention_mask"]
        labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
    # orpo chosen nll loss is computed over the full prompt and response
    chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=True,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    # The first `len_chosen` rows are the chosen responses, the remainder the rejected ones.
    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    if not self.is_encoder_decoder:
        chosen_logits = all_logits[:len_chosen, :-1, :]
        rejected_logits = all_logits[len_chosen:, :-1, :]
    else:
        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1089
+
1090
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.

    Args:
        model: Model forwarded to `self.concatenated_forward`.
        batch: Collated batch forwarded to `self.concatenated_forward`.
        train_eval: `"eval"` prefixes every metric name with `eval_`; any other value uses no prefix.

    Returns:
        A `(loss, metrics)` pair. `loss` is the NLL loss minus the mean odds-ratio term, plus the weighted
        auxiliary loss when it is enabled. `metrics` maps metric names to Python scalars.
    """
    forward_output = self.concatenated_forward(model, batch)
    chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss = forward_output[:5]
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
        chosen_logps, rejected_logps
    )
    # full ORPO loss
    loss = nll_loss - losses.mean()

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    gather = self.accelerator.gather_for_metrics
    prefix = "eval_" if train_eval == "eval" else ""
    # The entries are evaluated top to bottom, which keeps the order of the gather calls.
    metrics = {
        f"{prefix}rewards/chosen": gather(chosen_rewards).mean(),
        f"{prefix}rewards/rejected": gather(rejected_rewards).mean(),
        f"{prefix}rewards/accuracies": gather(reward_accuracies).mean(),
        f"{prefix}rewards/margins": gather(chosen_rewards - rejected_rewards).mean(),
        f"{prefix}logps/rejected": gather(rejected_logps).detach().mean(),
        f"{prefix}logps/chosen": gather(chosen_logps).detach().mean(),
        f"{prefix}logits/rejected": gather(rejected_logits.detach().mean()).mean(),
        f"{prefix}logits/chosen": gather(chosen_logits.detach().mean()).mean(),
        f"{prefix}nll_loss": gather(nll_loss).detach().mean(),
        f"{prefix}log_odds_ratio": gather(log_odds_ratio).detach().mean(),
        f"{prefix}log_odds_chosen": gather(log_odds_chosen).detach().mean(),
    }
    if is_torch_xla_available():
        xm.mark_step()  # needed because .item() calls
    metrics = {name: value.item() for name, value in metrics.items()}

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
1144
+
1145
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for one batch and record its metrics.

    Args:
        model: Model passed on to `self.get_batch_loss_metrics`.
        inputs: Batch passed on to `self.get_batch_loss_metrics`.
        return_outputs: When true, return `(loss, metrics)` instead of just the loss.
        num_items_in_batch: Accepted for signature compatibility; not used here.

    Returns:
        The loss moved to `self.args.device`, optionally paired with the metrics dict.
    """
    # Autocast is only entered when the PEFT model was cast to bf16 during initialisation.
    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with amp_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
    loss = loss.to(self.args.device)

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
1168
+
1169
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
    """Generate samples from the model for the given batch of inputs.

    Args:
        model: Model whose `generate` method is called with sampling enabled.
        batch: Batch providing `prompt_input_ids` and `prompt_attention_mask`.

    Returns:
        The result of `batch_decode` on the generated sequences (padded to `self.max_length`), with
        special tokens skipped.
    """
    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = (
        autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
    )

    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
        )

    policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
    policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

    return policy_output_decoded
1191
+
1192
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one evaluation step and return `(loss, logits, labels)`.

    Args:
        model: Model passed on to `self.get_batch_loss_metrics`.
        inputs: Batch passed on to `self.get_batch_loss_metrics`.
        prediction_loss_only: When true, only the detached loss is returned (logits and labels are `None`).
        ignore_keys: Metric keys to leave out of the returned logits. Defaults to the model config's
            `keys_to_ignore_at_inference` when available, otherwise to an empty list.

    Returns:
        A tuple `(loss, logits, labels)`. `logits` holds the mean chosen/rejected logit metrics that were not
        ignored and `labels` is a zero tensor with one entry per returned logit.
    """
    if not self.use_dpo_data_collator:
        warnings.warn(
            "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
            "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
        )

    if ignore_keys is None:
        # Falls back to an empty list when the model has no config or the config lacks the attribute.
        ignore_keys = getattr(getattr(model, "config", None), "keys_to_ignore_at_inference", [])

    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with torch.no_grad(), amp_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    kept_logits = [
        metrics[name] for name in ("eval_logits/chosen", "eval_logits/rejected") if name not in ignore_keys
    ]
    target_device = self.accelerator.device
    logits = torch.tensor(kept_logits, device=target_device)
    labels = torch.zeros(logits.shape[0], device=target_device)

    return (loss.detach(), logits, labels)
1233
+
1234
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to the per-split history kept in `self._stored_metrics`.

    Args:
        metrics: Metric name to value mapping for one batch.
        train_eval: Split under which the values are stored.
    """
    for name in metrics:
        # The split bucket is looked up per entry, so an empty `metrics` touches nothing.
        self._stored_metrics[train_eval][name].append(metrics[name])
1237
+
1238
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
    `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, one random batch is sampled from the evaluation dataset, completions
    are generated for it and the prompt/completion table is logged to the reporting backends listed in
    `self.args.report_to` (wandb and/or comet_ml) before the regular evaluation runs.

    Args:
        dataloader: Evaluation dataloader; its `dataset` is sampled for the generation preview.
        description: Forwarded to the parent evaluation loop.
        prediction_loss_only: Forwarded to the parent evaluation loop.
        ignore_keys: Forwarded to the parent evaluation loop.
        metric_key_prefix: Forwarded to the parent evaluation loop.

    Returns:
        The output of the parent class's `evaluation_loop`.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples. The sample size is capped
        # at the dataset size because `random.sample` raises when asked for more items than are available.
        num_samples = len(dataloader.dataset)
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        # The decoded output starts with the prompt text, so it is cut off to keep only the completion.
        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
1287
+
1288
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Merge the averaged stored metrics into `logs`, then hand off to the parent logger.

    Args:
        logs (`dict[str, float]`):
            The values to log. Mutated in place: one entry is added per stored metric key.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # A plain "loss" key marks a training log; anything else is treated as an eval log.
    phase = "eval" if "loss" not in logs else "train"

    # Collapse each per-batch history to its mean before logging.
    stored = self._stored_metrics[phase]
    for name, values in stored.items():
        logs[name] = torch.tensor(values).mean().item()

    # Drop the consumed bucket so the next logging window starts empty.
    del self._stored_metrics[phase]
    return super().log(logs, start_time)
1305
+
1306
def _shift_right(self, input_ids):
    """
    Return `input_ids` shifted one position to the right along the last dimension.

    Position 0 is filled with `self.decoder_start_token_id`, the last token of `input_ids` is dropped, and any
    `-100` entries in the result are replaced by `self.pad_token_id`.

    Raises:
        ValueError: if `self.decoder_start_token_id` or `self.pad_token_id` is `None`.
    """
    start_id = self.decoder_start_token_id
    if start_id is None:
        raise ValueError(
            "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
        )

    # shift inputs to the right
    if not is_torch_fx_proxy(input_ids):
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_id
    else:
        # Item assignment is not supported natively for proxies, so build the
        # result by concatenating a start-token column with the truncated input.
        start_column = torch.full(input_ids.shape[:-1] + (1,), start_id)
        shifted = torch.cat([start_column, input_ids[..., :-1]], dim=-1)

    pad_id = self.pad_token_id
    if pad_id is None:
        raise ValueError("model.config.pad_token_id has to be defined.")

    # replace possible -100 values in labels by `pad_token_id`
    shifted.masked_fill_(shifted == -100, pad_id)
    return shifted
1328
+
1329
+ # Ensure the model card is saved along with the checkpoint
1330
def _save_checkpoint(self, model, trial):
    """
    Write a model card, then delegate to the parent `_save_checkpoint`.

    The model name used for the card is the last path component of `self.args.output_dir` when no
    `hub_model_id` is configured, otherwise the part of `hub_model_id` after the final "/".
    """
    hub_id = self.args.hub_model_id
    model_name = Path(self.args.output_dir).name if hub_id is None else hub_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
1337
+
1338
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Build a draft model card from the trainer's state and save it as `README.md` in `self.args.output_dir`.

    Does nothing on processes for which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # Only record a base model when `_name_or_path` is not a local directory.
    base_model = None
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path

    # normalize `tags` to a mutable set
    if isinstance(tags, str):
        tag_set = {tags}
    elif tags is None:
        tag_set = set()
    else:
        tag_set = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tag_set.add("unsloth")

    tag_set.update(self._tag_names)

    citation = textwrap.dedent("""\
    @article{hong2024orpo,
        title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
        author = {Jiwoo Hong and Noah Lee and James Thorne},
        year = 2024,
        eprint = {arXiv:2403.07691}
    }""")

    # Link the active wandb run only when wandb is installed and a run exists.
    wandb_url = None
    if is_wandb_available() and wandb.run is not None:
        wandb_url = wandb.run.url

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tag_set,
        wandb_url=wandb_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="ORPO",
        trainer_citation=citation,
        paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
        paper_id="2403.07691",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    model_card.save(readme_path)
1399
# Thin wrapper around `_UnslothORPOTrainer`: its `__init__` normalizes precision, evaluation and
# data-collator settings (mostly by mutating `args` and a few environment variables), delegates to the
# parent constructor, then adjusts NEFTune hooks and propagates the accelerator's grad scaler.
class UnslothORPOTrainer(_UnslothORPOTrainer):
    """
    
    Initialize ORPOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        args (`ORPOConfig`):
            The ORPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
            a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        # NOTE: several checks below inspect `locals()` by name, so renaming or adding local
        # variables in this method can change its behavior.
        if args is None: args = UnslothORPOConfig()

        # --- Training precision ------------------------------------------------------------
        # Non-bool values for bf16 / fp16 are treated as "not requested".
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # NOTE(review): `model` defaults to None but is dereferenced unconditionally here, so a
        # None model raises AttributeError — confirm callers always pass a model.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject a requested mixed-precision mode that disagrees with the model dtype
        # (any non-float16 dtype is reported as bfloat16 by the second message).
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing was requested explicitly: pick the mode that matches the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # --- Evaluation defaults -----------------------------------------------------------
        # NOTE(review): this reads `eval_dataset` from `args`, not from the `eval_dataset`
        # parameter of this method — confirm that is intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Only warns (does not fail) on transformers <= 4.45.2.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps

        # --- Full-eval precision -----------------------------------------------------------
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        # Flip a full-eval flag that conflicts with the training precision chosen above.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            # Neither full-eval flag was set: mirror the training precision flags.
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # --- Logits output -----------------------------------------------------------------
        # Either metrics callback being supplied sets UNSLOTH_RETURN_LOGITS=1 in the environment.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # --- Sequence length ---------------------------------------------------------------
        # `max_seq_length` is not a parameter of this method and is not bound before this point,
        # so the `locals()` test is true here and the branch hinges on `hasattr(args, ...)`.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            # Inherit the model's max_seq_length when args leaves it unset.
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # --- Tokenizer / processor padding -------------------------------------------------
        # `tokenizer` is not a named parameter (it could only arrive inside `kwargs`), so the
        # first test is false here; `processing_class` is a parameter, so its test is true.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        # The double-underscore name is mangled by Python inside a class body; for a local
        # variable that is harmless because every use below is mangled the same way.
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer

        # --- Data collator selection -------------------------------------------------------
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap between the seq2seq and language-modeling collators depending on whether
            # the training dataset has a 'labels' column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator: override these dataset-related settings on `args` where present.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing object has no `pad` but wraps a `.tokenizer`: rebuild the collator
            # around the inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []
        
        # NOTE(review): presumably patches TRL's ORPO trainer logging with extra statistics —
        # confirm against `unsloth_zoo.logging_utils`.
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('orpo_trainer', other_metrics)
        
        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,**kwargs)

        # --- Post-init fixes ---------------------------------------------------------------
        # Remove any NEFTune hook handle left on `self` by the parent constructor; the alpha
        # is instead stored on the input embedding module below.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            # Attach the accelerator's scaler to the model and to every nested `.model` below it.
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        
pass
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.online_dpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, os, prepare_deepspeed, seed_worker, textwrap, torch, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options dictionary passed as `options=` to the `torch.compile` decorator on
# `chunked_selective_log_softmax` below.
# NOTE(review): the keys look like TorchInductor config names — confirm they are
# still accepted by the installed PyTorch version.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Compute the log-softmax of `logits` at the positions given by `index`, chunk by chunk.

    `logits` is flattened to rows of size `logits.shape[-1]` and `index` to a flat vector; both are
    split into 4 chunks along the row dimension. Each chunk is upcast to float32 and reduced to
    `logits[row, index[row]] - logsumexp(logits[row])`. The result is reshaped to
    `(logits.shape[0], logits.shape[1])`.
    """
    vocab_dim = logits.shape[-1]
    # Split into 4 chunks only
    logit_chunks = torch.chunk(logits.reshape(-1, vocab_dim), chunks = 4, dim = 0)
    index_chunks = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)

    pieces = []
    # Each iteration is equivalent to selective_log_softmax(rows, targets)
    for rows, targets in zip(logit_chunks, index_chunks):
        rows = rows.to(torch.float32)
        picked = rows.gather(dim = -1, index = targets.unsqueeze(-1)).squeeze(-1)
        normalizer = rows.logsumexp(dim = -1)
        pieces.append(picked - normalizer)

    flat_logps = torch.cat(pieces)
    return flat_logps.reshape((logits.shape[0], logits.shape[1]))
51
def vLLMSamplingParams(**kwargs):
    """
    Build a `vllm.SamplingParams` from `kwargs` and remember the kwargs on the instance.

    The keyword arguments used for construction are stored unchanged on the returned object as
    `_set_kwargs`. `vllm` is imported lazily, so it is only required when this function is called.
    """
    from vllm import SamplingParams

    params = SamplingParams(**kwargs)
    params._set_kwargs = kwargs
    return params
56
+ @dataclass
57
+ class UnslothOnlineDPOConfig(OnlineDPOConfig):
58
+ """
59
+
60
+ Configuration class for the [`OnlineDPOTrainer`].
61
+
62
+ This class includes only the parameters that are specific to Online DPO training. For a full list of training
63
+ arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
64
+ class may differ from those in [`~transformers.TrainingArguments`].
65
+
66
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
67
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
68
+ command line.
69
+
70
+ Parameters:
71
+ reward_model_path (`str` or `None`, *optional*, defaults to `None`):
72
+ Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
73
+ judge (`str` or `None`, *optional*, defaults to `None`):
74
+ Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
75
+ max_new_tokens (`int`, *optional*, defaults to `64`):
76
+ Maximum number of tokens to generate per completion.
77
+ max_length (`int`, *optional*, defaults to `256`):
78
+ Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
79
+ sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
80
+ possible.
81
+ temperature (`float`, *optional*, defaults to `0.9`):
82
+ Temperature for sampling. The higher the temperature, the more random the completions.
83
+ missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
84
+ Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to
85
+ generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
86
+ value.
87
+ beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
88
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
89
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
90
+ the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
91
+ selected for each new epoch and the last β is used for the rest of the epochs.
92
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
93
+ Type of loss to use. Possible values are:
94
+
95
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
96
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
97
+
98
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
99
+ Number of processes to use for processing the dataset.
100
+ disable_dropout (`bool`, *optional*, defaults to `True`):
101
+ Whether to disable dropout in the model and reference model.
102
+ use_vllm (`bool`, *optional*, defaults to `False`):
103
+ Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
104
+ vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
105
+ Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
106
+ the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
107
+ implementation.
108
+ gpu_memory_utilization (`float`, *optional*, defaults to `0.55`):
109
+ The vLLM memory utilization. The default value is 0.55.
110
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
111
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
112
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
113
+ capacity of a single GPU, albeit at the cost of slower generation.
114
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
115
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
116
+ string.
117
+
118
+ """
119
+ vllm_sampling_params: Optional[Any] = field(
120
+ default = None,
121
+ metadata = {'help': 'vLLM SamplingParams'},
122
+ )
123
+ unsloth_num_chunks : Optional[int] = field(
124
+ default = -1,
125
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
126
+ )
127
+ max_seq_length : Optional[int] = field(
128
+ default = None,
129
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
130
+ )
131
+ def __init__(
132
+ self,
133
+ output_dir = None,
134
+ overwrite_output_dir = None,
135
+ do_train = False,
136
+ do_eval = False,
137
+ do_predict = False,
138
+ eval_strategy = 'no',
139
+ prediction_loss_only = False,
140
+ per_device_train_batch_size = 4,
141
+ per_device_eval_batch_size = 4,
142
+ per_gpu_train_batch_size = None,
143
+ per_gpu_eval_batch_size = None,
144
+ gradient_accumulation_steps = 2,
145
+ eval_accumulation_steps = 2,
146
+ eval_delay = 0,
147
+ torch_empty_cache_steps = 250,
148
+ learning_rate = 5e-05,
149
+ weight_decay = 0.01,
150
+ adam_beta1 = 0.9,
151
+ adam_beta2 = 0.999,
152
+ adam_epsilon = 1e-08,
153
+ max_grad_norm = 1.0,
154
+ num_train_epochs = 3.0,
155
+ max_steps = -1,
156
+ lr_scheduler_type = 'linear',
157
+ warmup_ratio = 0.1,
158
+ warmup_steps = 0,
159
+ log_level = 'passive',
160
+ log_level_replica = 'warning',
161
+ log_on_each_node = True,
162
+ logging_dir = None,
163
+ logging_strategy = 'steps',
164
+ logging_first_step = False,
165
+ logging_steps = 1,
166
+ logging_nan_inf_filter = False,
167
+ save_strategy = 'steps',
168
+ save_steps = 500,
169
+ save_total_limit = None,
170
+ save_safetensors = True,
171
+ save_on_each_node = False,
172
+ save_only_model = False,
173
+ restore_callback_states_from_checkpoint = False,
174
+ no_cuda = False,
175
+ use_cpu = False,
176
+ use_mps_device = False,
177
+ seed = 3407,
178
+ data_seed = 3407,
179
+ jit_mode_eval = False,
180
+ use_ipex = False,
181
+ bf16 = False,
182
+ fp16 = False,
183
+ fp16_opt_level = 'O1',
184
+ half_precision_backend = 'auto',
185
+ bf16_full_eval = False,
186
+ fp16_full_eval = False,
187
+ tf32 = None,
188
+ local_rank = -1,
189
+ ddp_backend = None,
190
+ tpu_num_cores = None,
191
+ tpu_metrics_debug = False,
192
+ debug = '',
193
+ dataloader_drop_last = False,
194
+ eval_steps = None,
195
+ dataloader_num_workers = 0,
196
+ dataloader_prefetch_factor = None,
197
+ past_index = -1,
198
+ run_name = None,
199
+ disable_tqdm = None,
200
+ remove_unused_columns = True,
201
+ label_names = None,
202
+ load_best_model_at_end = False,
203
+ metric_for_best_model = None,
204
+ greater_is_better = None,
205
+ ignore_data_skip = False,
206
+ fsdp = '',
207
+ fsdp_min_num_params = 0,
208
+ fsdp_config = None,
209
+ fsdp_transformer_layer_cls_to_wrap = None,
210
+ accelerator_config = None,
211
+ deepspeed = None,
212
+ label_smoothing_factor = 0.0,
213
+ optim = 'adamw_8bit',
214
+ optim_args = None,
215
+ adafactor = False,
216
+ group_by_length = False,
217
+ length_column_name = 'length',
218
+ report_to = None,
219
+ ddp_find_unused_parameters = None,
220
+ ddp_bucket_cap_mb = None,
221
+ ddp_broadcast_buffers = None,
222
+ dataloader_pin_memory = True,
223
+ dataloader_persistent_workers = False,
224
+ skip_memory_metrics = True,
225
+ use_legacy_prediction_loop = False,
226
+ push_to_hub = False,
227
+ resume_from_checkpoint = None,
228
+ hub_model_id = None,
229
+ hub_strategy = 'every_save',
230
+ hub_token = None,
231
+ hub_private_repo = None,
232
+ hub_always_push = False,
233
+ hub_revision = None,
234
+ gradient_checkpointing = False,
235
+ gradient_checkpointing_kwargs = None,
236
+ include_inputs_for_metrics = False,
237
+ eval_do_concat_batches = True,
238
+ fp16_backend = 'auto',
239
+ push_to_hub_model_id = None,
240
+ push_to_hub_organization = None,
241
+ push_to_hub_token = None,
242
+ mp_parameters = '',
243
+ auto_find_batch_size = True,
244
+ full_determinism = False,
245
+ torchdynamo = None,
246
+ ray_scope = 'last',
247
+ ddp_timeout = 1800,
248
+ torch_compile = False,
249
+ torch_compile_backend = None,
250
+ torch_compile_mode = None,
251
+ include_tokens_per_second = False,
252
+ include_num_input_tokens_seen = False,
253
+ neftune_noise_alpha = None,
254
+ optim_target_modules = None,
255
+ batch_eval_metrics = False,
256
+ eval_on_start = False,
257
+ use_liger_kernel = False,
258
+ liger_kernel_config = None,
259
+ eval_use_gather_object = False,
260
+ average_tokens_across_devices = True,
261
+ reward_model_path = None,
262
+ judge = None,
263
+ max_new_tokens = 64,
264
+ max_length = 512,
265
+ temperature = 0.9,
266
+ missing_eos_penalty = None,
267
+ loss_type = 'sigmoid',
268
+ dataset_num_proc = None,
269
+ disable_dropout = True,
270
+ use_vllm = False,
271
+ vllm_model_impl = 'vllm',
272
+ gpu_memory_utilization = 0.55,
273
+ ds3_gather_for_generation = True,
274
+ model_init_kwargs = None,
275
+ vllm_sampling_params = None,
276
+ unsloth_num_chunks = -1,
277
+ max_seq_length = None,
278
+ **kwargs,
279
+ ):
280
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
281
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
282
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
283
+ output_dir = 'unsloth_training_checkpoints'
284
+ save_strategy = 'no'
285
+ if dataset_num_proc is None:
286
+ from multiprocessing import cpu_count
287
+ dataset_num_proc = min(cpu_count()*2, 2)
288
+ if temperature <= 0:
289
+ raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
290
+ elif temperature >= 10:
291
+ raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
292
+
293
+
294
+ super().__init__(
295
+ output_dir = output_dir,
296
+ overwrite_output_dir = overwrite_output_dir,
297
+ do_train = do_train,
298
+ do_eval = do_eval,
299
+ do_predict = do_predict,
300
+ eval_strategy = eval_strategy,
301
+ prediction_loss_only = prediction_loss_only,
302
+ per_device_train_batch_size = per_device_train_batch_size,
303
+ per_device_eval_batch_size = per_device_eval_batch_size,
304
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
305
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
306
+ gradient_accumulation_steps = gradient_accumulation_steps,
307
+ eval_accumulation_steps = eval_accumulation_steps,
308
+ eval_delay = eval_delay,
309
+ torch_empty_cache_steps = torch_empty_cache_steps,
310
+ learning_rate = learning_rate,
311
+ weight_decay = weight_decay,
312
+ adam_beta1 = adam_beta1,
313
+ adam_beta2 = adam_beta2,
314
+ adam_epsilon = adam_epsilon,
315
+ max_grad_norm = max_grad_norm,
316
+ num_train_epochs = num_train_epochs,
317
+ max_steps = max_steps,
318
+ lr_scheduler_type = lr_scheduler_type,
319
+ warmup_ratio = warmup_ratio,
320
+ warmup_steps = warmup_steps,
321
+ log_level = log_level,
322
+ log_level_replica = log_level_replica,
323
+ log_on_each_node = log_on_each_node,
324
+ logging_dir = logging_dir,
325
+ logging_strategy = logging_strategy,
326
+ logging_first_step = logging_first_step,
327
+ logging_steps = logging_steps,
328
+ logging_nan_inf_filter = logging_nan_inf_filter,
329
+ save_strategy = save_strategy,
330
+ save_steps = save_steps,
331
+ save_total_limit = save_total_limit,
332
+ save_safetensors = save_safetensors,
333
+ save_on_each_node = save_on_each_node,
334
+ save_only_model = save_only_model,
335
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
336
+ no_cuda = no_cuda,
337
+ use_cpu = use_cpu,
338
+ use_mps_device = use_mps_device,
339
+ seed = seed,
340
+ data_seed = data_seed,
341
+ jit_mode_eval = jit_mode_eval,
342
+ use_ipex = use_ipex,
343
+ bf16 = bf16,
344
+ fp16 = fp16,
345
+ fp16_opt_level = fp16_opt_level,
346
+ half_precision_backend = half_precision_backend,
347
+ bf16_full_eval = bf16_full_eval,
348
+ fp16_full_eval = fp16_full_eval,
349
+ tf32 = tf32,
350
+ local_rank = local_rank,
351
+ ddp_backend = ddp_backend,
352
+ tpu_num_cores = tpu_num_cores,
353
+ tpu_metrics_debug = tpu_metrics_debug,
354
+ debug = debug,
355
+ dataloader_drop_last = dataloader_drop_last,
356
+ eval_steps = eval_steps,
357
+ dataloader_num_workers = dataloader_num_workers,
358
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
359
+ past_index = past_index,
360
+ run_name = run_name,
361
+ disable_tqdm = disable_tqdm,
362
+ remove_unused_columns = remove_unused_columns,
363
+ label_names = label_names,
364
+ load_best_model_at_end = load_best_model_at_end,
365
+ metric_for_best_model = metric_for_best_model,
366
+ greater_is_better = greater_is_better,
367
+ ignore_data_skip = ignore_data_skip,
368
+ fsdp = fsdp,
369
+ fsdp_min_num_params = fsdp_min_num_params,
370
+ fsdp_config = fsdp_config,
371
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
372
+ accelerator_config = accelerator_config,
373
+ deepspeed = deepspeed,
374
+ label_smoothing_factor = label_smoothing_factor,
375
+ optim = optim,
376
+ optim_args = optim_args,
377
+ adafactor = adafactor,
378
+ group_by_length = group_by_length,
379
+ length_column_name = length_column_name,
380
+ report_to = report_to,
381
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
382
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
383
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
384
+ dataloader_pin_memory = dataloader_pin_memory,
385
+ dataloader_persistent_workers = dataloader_persistent_workers,
386
+ skip_memory_metrics = skip_memory_metrics,
387
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
388
+ push_to_hub = push_to_hub,
389
+ resume_from_checkpoint = resume_from_checkpoint,
390
+ hub_model_id = hub_model_id,
391
+ hub_strategy = hub_strategy,
392
+ hub_token = hub_token,
393
+ hub_private_repo = hub_private_repo,
394
+ hub_always_push = hub_always_push,
395
+ hub_revision = hub_revision,
396
+ gradient_checkpointing = gradient_checkpointing,
397
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
398
+ include_inputs_for_metrics = include_inputs_for_metrics,
399
+ eval_do_concat_batches = eval_do_concat_batches,
400
+ fp16_backend = fp16_backend,
401
+ push_to_hub_model_id = push_to_hub_model_id,
402
+ push_to_hub_organization = push_to_hub_organization,
403
+ push_to_hub_token = push_to_hub_token,
404
+ mp_parameters = mp_parameters,
405
+ auto_find_batch_size = auto_find_batch_size,
406
+ full_determinism = full_determinism,
407
+ torchdynamo = torchdynamo,
408
+ ray_scope = ray_scope,
409
+ ddp_timeout = ddp_timeout,
410
+ torch_compile = torch_compile,
411
+ torch_compile_backend = torch_compile_backend,
412
+ torch_compile_mode = torch_compile_mode,
413
+ include_tokens_per_second = include_tokens_per_second,
414
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
415
+ neftune_noise_alpha = neftune_noise_alpha,
416
+ optim_target_modules = optim_target_modules,
417
+ batch_eval_metrics = batch_eval_metrics,
418
+ eval_on_start = eval_on_start,
419
+ use_liger_kernel = use_liger_kernel,
420
+ liger_kernel_config = liger_kernel_config,
421
+ eval_use_gather_object = eval_use_gather_object,
422
+ average_tokens_across_devices = average_tokens_across_devices,
423
+ reward_model_path = reward_model_path,
424
+ judge = judge,
425
+ max_new_tokens = max_new_tokens,
426
+ max_length = max_length,
427
+ temperature = temperature,
428
+ missing_eos_penalty = missing_eos_penalty,
429
+ loss_type = loss_type,
430
+ dataset_num_proc = dataset_num_proc,
431
+ disable_dropout = disable_dropout,
432
+ use_vllm = use_vllm,
433
+ vllm_model_impl = vllm_model_impl,
434
+ gpu_memory_utilization = gpu_memory_utilization,
435
+ ds3_gather_for_generation = ds3_gather_for_generation,
436
+ model_init_kwargs = model_init_kwargs,**kwargs)
437
+ self.vllm_sampling_params = vllm_sampling_params
438
+ self.unsloth_num_chunks = unsloth_num_chunks
439
+ self.max_seq_length = max_seq_length
440
+ pass
441
+
442
+ class _UnslothOnlineDPOTrainer(Trainer):
443
+ r""""""
444
+
445
+ _tag_names = ["trl", "online-dpo"]
446
+
447
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module, str],
    ref_model: Union[PreTrainedModel, nn.Module, None] = None,
    reward_model: Union[PreTrainedModel, nn.Module, None] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[OnlineDPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
    """Validate the arguments, resolve the models and initialise the base ``Trainer``.

    Args:
        model: Policy model, or a string handed to
            ``AutoModelForCausalLM.from_pretrained`` together with
            ``args.model_init_kwargs``.
        ref_model: Optional reference model; must not be the same object as
            ``model``. When ``None`` no reference model is stored and
            ``training_step`` relies on ``self.model.disable_adapter()``.
        reward_model: Optional reward model used to score completions.
        judge: Optional pairwise judge; ignored (with a warning) when
            ``reward_model`` is also given.
        args: Training configuration. Required.
        data_collator: Collator; defaults to ``DPODataCollatorWithPadding``
            built from ``processing_class.pad_token_id``.
        train_dataset: Forwarded to the base ``Trainer``.
        eval_dataset: Forwarded to the base ``Trainer``.
        processing_class: Tokenizer/processor of the policy. Required.
        reward_processing_class: Tokenizer used for the reward model.
        peft_config: Accepted for interface compatibility; the PEFT
            conversion branch is disabled (``if False``) in this file, so the
            value is not applied here.
        compute_metrics: Forwarded to the base ``Trainer``.
        callbacks: Forwarded to the base ``Trainer``.
        optimizers: Forwarded to the base ``Trainer``.
        preprocess_logits_for_metrics: Forwarded to the base ``Trainer``.

    Raises:
        ValueError: If ``ref_model is model``; if neither ``reward_model`` nor
            ``judge`` is given; if ``args`` or ``processing_class`` is
            ``None``; if ``missing_eos_penalty`` is combined with a judge; if
            ``torch_dtype`` in ``model_init_kwargs`` is invalid; or if
            ``model_init_kwargs`` is given for an already instantiated model.
    """
    # A model exposing `vllm_engine` forces vLLM generation on, even if the
    # config left `use_vllm` disabled.
    if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
        if (getattr(args, 'use_vllm', False) == False):
            args.use_vllm = True
    if ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, either omit the `ref_model` argument or pass `None`."
        )

    self.ref_model = ref_model

    if reward_model is not None and judge is not None:
        warnings.warn(
            "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
            "Ignoring `judge` and using `reward_model`.",
            UserWarning,
        )
        judge = None
    elif reward_model is None and judge is None:
        raise ValueError("Either `reward_model` or `judge` must be provided.")

    self.reward_model = reward_model
    self.reward_processing_class = reward_processing_class
    self.judge = judge

    # `args` must be validated before its first attribute access; otherwise a
    # missing config surfaces as an AttributeError instead of this message.
    if args is None:
        raise ValueError("`args` must be provided.")

    if args.missing_eos_penalty is not None and judge is not None:
        raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")

    # Check that the processing_class is provided
    if processing_class is None:
        raise ValueError("`processing_class` must be provided.")

    model_init_kwargs = args.model_init_kwargs or {}
    if isinstance(model, str):
        model_id = model

        # Handle torch_dtype in model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass
        elif isinstance(torch_dtype, str):
            torch_dtype = getattr(torch, torch_dtype)
            model_init_kwargs["torch_dtype"] = torch_dtype
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string "
                f"representing a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )

        model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
    else:
        if args.model_init_kwargs is not None:
            raise ValueError(
                "You passed `model_init_kwargs` to the `OnlineDPOConfig`, but your model is already instantiated. "
                "This argument can only be used when the `model` argument is a string."
            )
    self.is_encoder_decoder = model.config.is_encoder_decoder

    # Convert to PEFT model if peft_config is provided
    # NOTE: this branch is deliberately disabled in the generated file.
    if False:
        # Check if PEFT is available
        if not is_peft_available():
            raise ImportError(
                "PEFT is not available and passed `peft_config`. Please install PEFT with "
                "`pip install peft` to use it."
            )

        # If the model is already a PeftModel, we need to merge and unload it.
        # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        # Get peft model with the given config
        model = model

    # Disable dropout in the model and reference model
    if args.disable_dropout:
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

    # Handle the ref_model
    # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
    # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
    # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
    if ref_model is None:  # No ref model provided, the most common case
        if False:
            self.ref_model = create_reference_model(model)  # copy, disable gradients, set eval mode
        else:
            self.ref_model = None  # we don't need a ref model here, we can just disable the adapter.
    else:  # rare case, the user provided a ref model
        self.ref_model = ref_model
        self.ref_model.eval()

    # Disable the gradient and set the reward model in eval mode
    if self.reward_model is not None:
        self.reward_model.eval()

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)

    self.max_length = args.max_length

    # Per-step metric buffers; each list collects one value per training step.
    self.stats = {
        "objective/kl": [],
        "objective/entropy": [],
        "objective/non_score_reward": [],
        "rewards/chosen": [],
        "rewards/rejected": [],
        "rewards/accuracies": [],
        "rewards/margins": [],
        "logps/chosen": [],
        "logps/rejected": [],
        "val/contain_eos_token": [],
        "beta": [],
    }
    if self.reward_model is not None:
        self.stats["objective/rlhf_reward"] = []
        self.stats["objective/scores_margin"] = []
        self.stats["objective/scores"] = []

    if args.use_vllm:
        self.llm = model.vllm_engine
        self._last_loaded_step = 0
        # n=2: two completions are sampled per prompt. Extra sampling
        # parameters come from `args.vllm_sampling_params` when it is set.
        self.generation_config = SamplingParams(
            n=2,
            max_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=50,
            top_p=1.0,
            detokenize=False,
            **getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),
        )
    else:
        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=50,
            top_p=1.0,
            do_sample=True,
            use_cache=not args.gradient_checkpointing,
        )

    # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
    # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
    # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    self._beta = args.beta

    # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
    if self.is_deepspeed_enabled:
        if self.reward_model is not None:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
        if self.ref_model is not None:
            self.ref_model = prepare_deepspeed(
                self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
    else:
        if self.ref_model is not None:
            self.ref_model = self.ref_model.to(self.accelerator.device)
        if self.reward_model is not None:
            self.reward_model = self.reward_model.to(self.accelerator.device)
655
+
656
@property
def beta(self):
    """Return the beta coefficient for the current point in training.

    When ``self._beta`` is a list it is treated as a per-epoch schedule: the
    entry for the current (whole) epoch is returned, and the last entry is
    reused once training runs past the end of the list. Otherwise
    ``self._beta`` is returned unchanged.
    """
    if not isinstance(self._beta, list):
        return self._beta

    # `state.epoch` is a fractional float during training and may still be
    # None before the first step; a list index must be an int, so truncate
    # to the whole epoch (and fall back to 0 when unset).
    epoch = self.state.epoch
    index = int(epoch) if epoch is not None else 0
    return self._beta[index] if index < len(self._beta) else self._beta[-1]
663
+
664
@staticmethod
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
    """Tokenize the ``"prompt"`` field of one dataset row.

    For decoder-only models the prompt is tokenized without special tokens
    and, when the tokenizer defines a BOS token, that token is prepended
    unless it already leads the sequence. For encoder-decoder models the
    tokenizer adds its own special tokens. Every key of the tokenizer output
    is returned with a ``prompt_`` prefix.
    """
    if is_encoder_decoder:
        encoded = tokenizer(feature["prompt"], add_special_tokens=True)
    else:
        encoded = tokenizer(feature["prompt"], add_special_tokens=False)
        # Add BOS token to head of prompt. Avoid adding if it's already there
        if tokenizer.bos_token_id is not None:
            token_count = len(encoded["input_ids"])
            starts_with_bos = token_count != 0 and tokenizer.bos_token_id == encoded["input_ids"][0]
            if not starts_with_bos:
                encoded["input_ids"] = [tokenizer.bos_token_id] + encoded["input_ids"]
                encoded["attention_mask"] = [1] + encoded["attention_mask"]

    prefixed = {}
    for key, value in encoded.items():
        prefixed[f"prompt_{key}"] = value
    return prefixed
679
+
680
+ # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
681
@wraps(Trainer.get_train_dataloader)
def get_train_dataloader(self) -> DataLoader:
    """Build the training ``DataLoader`` and pass it through the accelerator.

    Unlike the base implementation named in the decorator, no column removal
    is applied to the dataset here.

    Raises:
        ValueError: If ``self.train_dataset`` is ``None``.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    dataset = self.train_dataset
    collate_fn = self.data_collator
    loader_kwargs = dict(
        batch_size=self._train_batch_size,
        collate_fn=collate_fn,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        persistent_workers=self.args.dataloader_persistent_workers,
    )

    # Sampler-related options are only set for non-iterable datasets.
    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
    if not is_iterable:
        loader_kwargs["sampler"] = self._get_train_sampler()
        loader_kwargs["drop_last"] = self.args.dataloader_drop_last
        loader_kwargs["worker_init_fn"] = seed_worker
        loader_kwargs["prefetch_factor"] = self.args.dataloader_prefetch_factor

    loader = DataLoader(dataset, **loader_kwargs)
    return self.accelerator.prepare(loader)
703
+
704
+ # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
705
@wraps(Trainer.get_eval_dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
    """Build (or reuse) the evaluation ``DataLoader`` and prepare it.

    Args:
        eval_dataset: A dataset, the string key of an entry in
            ``self.eval_dataset``, or ``None`` to use ``self.eval_dataset``.

    Raises:
        ValueError: If no evaluation dataset is available.
    """
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")

    persistent = self.args.dataloader_persistent_workers

    # With persistent workers, reuse the stored loader instead of spawning a
    # new worker pool on every evaluation (eval datasets don't change).
    cache_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
    if hasattr(self, "_eval_dataloaders") and cache_key in self._eval_dataloaders and persistent:
        return self.accelerator.prepare(self._eval_dataloaders[cache_key])

    # Resolve which dataset to iterate over.
    if isinstance(eval_dataset, str):
        dataset = self.eval_dataset[eval_dataset]
    elif eval_dataset is not None:
        dataset = eval_dataset
    else:
        dataset = self.eval_dataset
    collate_fn = self.data_collator

    loader_kwargs = dict(
        batch_size=self.args.eval_batch_size,
        collate_fn=collate_fn,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        persistent_workers=self.args.dataloader_persistent_workers,
    )

    if not isinstance(dataset, torch.utils.data.IterableDataset):
        loader_kwargs["sampler"] = self._get_eval_sampler(dataset)
        loader_kwargs["drop_last"] = self.args.dataloader_drop_last
        loader_kwargs["prefetch_factor"] = self.args.dataloader_prefetch_factor

    # accelerator.free_memory() will destroy the references, so
    # we need to store the non-prepared version
    loader = DataLoader(dataset, **loader_kwargs)
    if self.args.dataloader_persistent_workers:
        if not hasattr(self, "_eval_dataloaders"):
            self._eval_dataloaders = {}
        self._eval_dataloaders[cache_key] = loader

    return self.accelerator.prepare(loader)
752
+
753
def _generate_vllm(self, model, prompts):
    """Sample two completions per prompt with the vLLM engine.

    Args:
        model: Unused here; generation goes through ``self.llm`` with the
            LoRA request obtained from ``self.model.load_lora``.
        prompts: Batch of prompts; ``self.llm.chat`` is used when the first
            prompt is conversational, ``self.llm.generate`` otherwise.

    Returns:
        ``(prompt_ids, prompt_mask, completion_ids, completion_mask)`` as
        tensors on ``self.accelerator.device``. Prompts are left-padded to
        the longest prompt; completions are right-padded to
        ``self.generation_config.max_tokens``. Rows hold the first completion
        of every prompt, followed by the second completion of every prompt.
    """
    eos_token_id = self.processing_class.eos_token_id
    pad_token_id = self.processing_class.pad_token_id

    # Pick the engine entry point first, then build the LoRA request once
    # instead of repeating the call in both branches.
    run = self.llm.chat if is_conversational({"prompt": prompts[0]}) else self.llm.generate
    lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True)
    outputs = run(prompts, self.generation_config, use_tqdm=False, lora_request = lora_request)

    completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
    prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]

    # Create mask and pad the prompt and completion
    max_prompt_length = max(len(ids) for ids in prompt_ids)
    prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
    prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
    max_tokens = self.generation_config.max_tokens
    # The mask is built from the generated tokens only, i.e. before the EOS
    # token is appended below.
    completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
    # Append EOS when there is room and the completion does not already end
    # with it. `not ids` guards the empty completion, where `ids[-1]` would
    # raise IndexError.
    completion_ids = [
        ids + [eos_token_id] if len(ids) < max_tokens and (not ids or ids[-1] != eos_token_id) else ids
        for ids in completion_ids
    ]
    completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]

    # Convert to tensors
    device = self.accelerator.device
    prompt_ids = torch.tensor(prompt_ids, device=device)
    prompt_mask = torch.tensor(prompt_mask, device=device)
    completion_ids = torch.tensor(completion_ids, device=device)
    completion_mask = torch.tensor(completion_mask, device=device)

    return prompt_ids, prompt_mask, completion_ids, completion_mask
790
+
791
def _generate(self, model, prompts):
    """Sample two completions per prompt with ``model.generate``.

    Prompts are templated (when applicable) and tokenized on the fly with
    ``self.processing_class``, collated, duplicated along the batch axis and
    fed to the unwrapped model.

    Returns:
        ``(prompt_ids, prompt_mask, completion_ids, completion_mask)``; the
        completion tensors come from ``truncate_right``.
    """
    eos_token_id = self.processing_class.eos_token_id
    pad_token_id = self.processing_class.pad_token_id

    # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
    # policies with different tokenizers / chat templates.
    examples = [{"prompt": prompt} for prompt in prompts]
    templated = [maybe_apply_chat_template(example, self.processing_class) for example in examples]
    tokenized = [
        self.tokenize_row(example, self.is_encoder_decoder, self.processing_class) for example in templated
    ]
    batch = self._prepare_inputs(self.data_collator(tokenized))

    # Sample 2 completions per prompt of size `max_new_tokens` from the model
    prompt_ids = batch["prompt_input_ids"].repeat(2, 1)
    prompt_mask = batch["prompt_attention_mask"].repeat(2, 1)
    generation_context = unwrap_model_for_generation(
        model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    )
    with generation_context as unwrapped_model:
        generated = unwrapped_model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            generation_config=self.generation_config,
        )

    # Strip the prompt tokens that lead every generated sequence.
    prompt_width = prompt_ids.size(1)
    completion_ids, completion_mask = truncate_right(generated[:, prompt_width:], eos_token_id, pad_token_id)

    return prompt_ids, prompt_mask, completion_ids, completion_mask
819
+
820
+ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
821
+ # Get the number of tokens to truncate from prompt
822
+ num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
823
+
824
+ # Truncate left to avoid oom
825
+ prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
826
+ prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
827
+
828
+ # Concat the prompt and completion
829
+ prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
830
+ prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
831
+
832
+ # Get the logprobs of the completions from the model
833
+ output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
834
+
835
+ # There is 1 offset, because the model predict the next token
836
+ prompt_len = prompt_ids.size(1)
837
+ start_idx = prompt_len - 1 if prompt_len > 0 else 0
838
+ logits = output.logits[:, start_idx:-1]
839
+
840
+ # Take the completion tokens logprob
841
+ logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
842
+ return logprobs
843
+
844
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one online-DPO step: sample, rank, compute the loss and backpropagate.

    Two completions are generated for every prompt in ``inputs["prompt"]``,
    ranked by the judge or the reward model, and the preferred / dispreferred
    pair is used for the DPO loss selected by ``self.args.loss_type``.
    Metrics are appended to ``self.stats``.

    Args:
        model: The policy model being trained.
        inputs: Batch containing a ``"prompt"`` entry.
        num_items_in_batch: Not used in this method.

    Returns:
        The detached loss divided by ``gradient_accumulation_steps``.

    Raises:
        NotImplementedError: If ``self.args.loss_type`` is neither
            ``"sigmoid"`` nor ``"ipo"``.
    """
    model.train()

    prompts = inputs["prompt"]
    batch_size = len(prompts)

    if self.args.use_vllm:
        prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
    else:
        prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)

    contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)

    logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
    with torch.no_grad():
        if self.ref_model is not None:
            ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
        else:  # peft case: we just need to disable the adapter
            with self.model.disable_adapter():
                ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)

    # Decode the completions, and format them if the input is conversational
    device = logprobs.device
    completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational({"prompt": prompts[0]}):
        completions = [[{"role": "assistant", "content": completion}] for completion in completions]

    # Get the reward from the reward model or judge
    if self.judge is not None:
        # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
        # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
        # independent of the model's chat template, we use the raw conversation data, and apply our own chat
        # template to it.
        if is_conversational({"prompt": prompts[0]}):
            environment = jinja2.Environment()
            template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
            prompts = [template.render(messages=prompt) for prompt in prompts]
            completions = [template.render(messages=completion) for completion in completions]

        ranks_of_first_completion = self.judge.judge(
            prompts, list(zip(completions[:batch_size], completions[batch_size:]))
        )

        # convert ranks to a True/False mask:
        # when rank == 0, it means the first completion is the best
        # when rank == 1, it means the second completion is the best
        mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
    else:
        # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
        # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
        prompts = 2 * prompts  # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
        if is_conversational({"prompt": prompts[0]}):
            examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
            examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
            prompts = [example["prompt"] for example in examples]
            completions = [example["completion"] for example in examples]

        # Tokenize the prompts
        prompts_ids = self.reward_processing_class(
            prompts, padding=True, return_tensors="pt", padding_side="left"
        )["input_ids"].to(device)
        context_length = prompts_ids.shape[1]

        # Tokenize the completions
        completions_ids = self.reward_processing_class(
            completions, padding=True, return_tensors="pt", padding_side="right"
        )["input_ids"].to(device)

        # Concatenate the prompts and completions and get the reward
        prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
        with torch.inference_mode():
            _, scores, _ = get_reward(
                self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
            )

            # Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty

        # Split the scores in 2 (the prompts of the first half are the same as the second half)
        first_half, second_half = scores.split(batch_size)

        # Get the indices of the chosen and rejected examples
        mask = first_half >= second_half

    # `mask[i]` is True when the first completion of prompt i is preferred;
    # the second completion of prompt i lives at row i + batch_size.
    batch_range = torch.arange(batch_size, device=device)
    chosen_indices = batch_range + (~mask * batch_size)
    rejected_indices = batch_range + (mask * batch_size)

    # Build tensor so that the first half is the chosen examples and the second half the rejected examples
    cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0)  # cr = chosen and rejected
    cr_logprobs = logprobs[cr_indices]
    cr_ref_logprobs = ref_logprobs[cr_indices]

    # mask out the padding tokens
    padding_mask = ~completion_mask.bool()
    cr_padding_mask = padding_mask[cr_indices]

    cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
    cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)

    # Split the chosen and rejected examples
    chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
    chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
    pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
    ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum

    logits = pi_logratios - ref_logratios

    if self.args.loss_type == "sigmoid":
        losses = -F.logsigmoid(self.beta * logits)
    elif self.args.loss_type == "ipo":
        losses = (logits - 1 / (2 * self.beta)) ** 2
    else:
        # The loss type lives on the config; `self.loss_type` does not exist
        # and would mask this error with an AttributeError.
        raise NotImplementedError(f"invalid loss type {self.args.loss_type}")

    loss = losses.mean()

    # Log everything
    if self.reward_model is not None:
        scores_margin = scores[chosen_indices] - scores[rejected_indices]
        self.stats["objective/scores_margin"].append(
            self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
        )
        self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
    self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
    self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
    self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())

    kl = logprobs - ref_logprobs
    mean_kl = kl.sum(1).mean()
    self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
    non_score_reward = (-self.beta * kl).sum(1)
    mean_non_score_reward = non_score_reward.mean()
    self.stats["objective/non_score_reward"].append(
        self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
    )
    if self.reward_model is not None:
        rlhf_reward = scores + non_score_reward
        self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
    mean_entropy = -logprobs.sum(1).mean()
    self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
    chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
    gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
    self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
    rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
    gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
    self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
    margin = gathered_chosen_rewards - gathered_rejected_rewards
    self.stats["rewards/margins"].append(margin.mean().item())
    accuracy = margin > 0
    self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
    self.stats["beta"].append(self.beta)

    if (
        self.args.torch_empty_cache_steps is not None
        and self.state.global_step % self.args.torch_empty_cache_steps == 0
    ):
        empty_cache()

    kwargs = {}

    # For LOMO optimizers you need to explicitly use the learning rate
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    if self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        self.accelerator.backward(loss, **kwargs)

    return loss.detach() / self.args.gradient_accumulation_steps
1023
+
1024
+ # Same as Trainer._maybe_log_save_evaluate but log our metrics
1025
def _maybe_log_save_evaluate(
    self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
):
    """
    Log, evaluate and save according to the flags on `self.control`.

    Follows the same flow as `Trainer._maybe_log_save_evaluate`, and additionally logs the average of
    every list accumulated in `self.stats` since the previous log, then clears those lists.

    Args:
        tr_loss: Running loss; it is gathered for logging and then zeroed in place.
        grad_norm: Gradient norm to log (tensor or plain number), or `None` to omit it.
        model: Passed to `self._save_checkpoint` when a save is due.
        trial: Passed to `self._evaluate`, `self._determine_best_metric` and `self._save_checkpoint`.
        epoch: Not read by this method.
        ignore_keys_for_eval: Passed to `self._evaluate`.
        start_time: Passed to `self.log`.
        learning_rate: Learning rate to log; when `None`, `self._get_learning_rate()` is used instead.
    """
    if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
        logs: dict[str, float] = {}

        # all_gather + mean() to get average loss over all processes
        tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

        # reset tr_loss to zero (in place, so the caller's tensor is cleared too)
        tr_loss -= tr_loss

        logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
        if grad_norm is not None:
            logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
        if learning_rate is not None:
            logs["learning_rate"] = learning_rate
        else:
            logs["learning_rate"] = self._get_learning_rate()

        # Add our metrics. A key that collected no samples since the last log is skipped:
        # averaging an empty list would raise ZeroDivisionError and abort training.
        for key, val in self.stats.items():
            if val:
                logs[key] = sum(val) / len(val)
        self.stats = {key: [] for key in self.stats}  # reset stats

        self._total_loss_scalar += tr_loss_scalar
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()
        self.log(logs, start_time)

    metrics = None
    if self.control.should_evaluate:
        metrics = self._evaluate(trial, ignore_keys_for_eval)
        is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

        # With the "best" strategy a checkpoint is written only when the metric improved.
        if self.args.save_strategy == "best":
            self.control.should_save = is_new_best_metric

    if self.control.should_save:
        self._save_checkpoint(model, trial)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1066
+
1067
+ # Ensure the model card is saved along with the checkpoint
1068
def _save_checkpoint(self, model, trial):
    """
    Write a model card, then delegate to the parent class to save the checkpoint.

    The card's model name is the last `/`-separated segment of `args.hub_model_id` when one is set,
    otherwise the final path component of `args.output_dir`.
    """
    hub_id = self.args.hub_model_id
    card_name = Path(self.args.output_dir).name if hub_id is None else hub_id.split("/")[-1]
    self.create_model_card(model_name=card_name)
    super()._save_checkpoint(model, trial)
1075
+
1076
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `args.output_dir`. Nothing is done on processes for
    which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # normalize `tags` to a mutable set
    if tags is None:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    # The backslash of the LaTeX accent is doubled on purpose: in a non-raw string `\'` collapses to
    # a bare quote, which would emit the invalid BibTeX `Ram{'{e}}` instead of `Ram{\'{e}}`.
    citation = textwrap.dedent("""\
    @article{guo2024direct,
        title        = {{Direct Language Model Alignment from Online AI Feedback}},
        author       = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
        year         = 2024,
        eprint       = {arXiv:2402.04792}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        # `wandb.run` is only touched when wandb is importable and a run is active.
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="Online DPO",
        trainer_citation=citation,
        paper_title="Direct Language Model Alignment from Online AI Feedback",
        paper_id="2402.04792",
    )
    model_card.save(os.path.join(self.args.output_dir, "README.md"))
1136
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
    """

    Initialize OnlineDPOTrainer.

    Before delegating to the parent constructor, `__init__` adjusts `args` in place (mixed-precision
    flags, evaluation settings, `max_seq_length`), sets a few environment variables
    (`ACCELERATE_MIXED_PRECISION`, `UNSLOTH_RETURN_LOGITS`), forces right padding on the processing
    class and may replace `data_collator`.

    Args:
        model (`Union[str, nn.Module, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.

            NOTE(review): this wrapper reads `model.config` and calls `model.get_input_embeddings()` before the
            parent constructor runs, so a plain string does not look usable here - confirm.
        ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reference model to use for training. If None is specified, the reference model will be created from the
            model.
        reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`OnlineDPOConfig`):
            The online DPO config arguments to use for training. Defaults to `UnslothOnlineDPOConfig()` when `None`.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        reward_processing_class:
            Forwarded unchanged to the parent constructor.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training. Not a named parameter of this wrapper; it can only be
            supplied through `**kwargs`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        **kwargs:
            Forwarded unchanged to the parent constructor.

    Raises:
        TypeError: If the requested `fp16`/`bf16` flag contradicts the detected model dtype (unless
            `UNSLOTH_FORCE_FLOAT32=1`).

    """
    def __init__(
        self,
        model,
        ref_model = None,
        reward_model = None,
        judge = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        reward_processing_class = None,
        peft_config = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothOnlineDPOConfig()

        # --- Mixed precision selection -------------------------------------------------------------
        # Anything that is not exactly a bool is treated as "flag not set".
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Model dtype: prefer the config entry, fall back to the input embedding dtype.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)  # presumably normalizes strings to torch dtypes - TODO confirm
        float16 = dtype == torch.float16
        # Reject a requested precision that contradicts the model's own dtype.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag requested: derive both from the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # --- Evaluation defaults -------------------------------------------------------------------
        # NOTE(review): this looks up `eval_dataset` on `args`, not the `eval_dataset` parameter of
        # this constructor. If the config has no such attribute this branch never runs - confirm intent.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Only a printed warning for transformers <= 4.45.2; nothing is changed.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        # Make the full-eval precision agree with the training precision chosen above.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # --- Logits for metrics --------------------------------------------------------------------
        # Either metrics hook being supplied sets UNSLOTH_RETURN_LOGITS=1 in the environment.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # --- Sequence length -----------------------------------------------------------------------
        # No local `max_seq_length` is bound before this point, so the test effectively reduces to
        # "args has no `max_seq_length` attribute".
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            # Inherit the model's limit when the config leaves it unset.
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # --- Padding side and data collator --------------------------------------------------------
        # This constructor has no `tokenizer` parameter or local, so the `tokenizer` branches below
        # never run; `processing_class` is a parameter and is therefore always in locals().
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator type depending on whether the dataset has a `labels` column.
            # NOTE(review): `train_dataset.column_names` is read without a None check.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator: keep all columns and skip dataset preparation.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` but wraps a `.tokenizer`: rebuild the collator
            # around that inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        # Called before the parent constructor; presumably installs logging patches for this
        # trainer module - TODO confirm against unsloth_zoo.
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('online_dpo_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            judge = judge,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            reward_processing_class = reward_processing_class,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
        # Remove any NEFTune hook handle left on the trainer and instead store the noise alpha
        # directly on the input embedding module.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            # Attach the accelerator's scaler to the model and to every nested `.model` beneath it.
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
1326
+
1327
+ pass
unsloth_compiled_cache/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, Path, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_rich_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb, Optional, PeftModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options dictionary handed to `torch.compile(..., options=...)` by the compiled helper below.
# NOTE(review): the keys look like TorchInductor config names - confirm their exact meaning
# against the PyTorch version in use before changing any value.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-softmax value of `logits` at the positions selected by `index`.

    The logits are flattened to rows of size `logits.shape[-1]`, split into at most 4 chunks along
    the row axis, and each chunk is upcast to float32 before computing
    `selected_logit - logsumexp(row)`. The result is reshaped to
    `(logits.shape[0], logits.shape[1])`.

    Assumes `logits` is 3-D with `index` holding one entry per row of the flattened logits -
    TODO confirm against callers.
    """
    vocab_size = logits.shape[-1]
    logit_chunks = torch.chunk(logits.reshape(-1, vocab_size), chunks = 4, dim = 0)
    index_chunks = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)

    pieces = []
    for rows, picks in zip(logit_chunks, index_chunks):
        rows = rows.to(torch.float32)
        # log_softmax(x)[i] == x[i] - logsumexp(x); only the selected column is materialized.
        chosen = torch.gather(rows, dim = -1, index = picks.unsqueeze(-1)).squeeze(-1)
        normalizer = torch.logsumexp(rows, dim = -1)
        pieces.append(chosen - normalizer)

    flat_logps = torch.concat(pieces)
    return flat_logps.reshape((logits.shape[0], logits.shape[1]))
51
+ @dataclass
52
+ class UnslothPPOConfig(PPOConfig):
53
+ """
54
+
55
+ Configuration class for the [`PPOTrainer`].
56
+
57
+ This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
58
+ please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
59
+ values in this class may differ from those in [`~transformers.TrainingArguments`].
60
+
61
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
62
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
63
+ command line.
64
+
65
+ Parameters:
66
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
67
+ Name of this experiment.
68
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
69
+ Path to the reward model.
70
+ model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
71
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
72
+ ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
73
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
74
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
75
+ Number of epochs to train.
76
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
77
+ Whether to whiten the rewards.
78
+ kl_coef (`float`, *optional*, defaults to `0.05`):
79
+ KL coefficient.
80
+ kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
81
+ Which estimator for KL-Divergence to use from [Approximating KL
82
+ Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
83
+ estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
84
+ better estimator". Cannot be set to "k2", as it is used for logging purposes.
85
+ cliprange (`float`, *optional*, defaults to `0.2`):
86
+ Clip range.
87
+ vf_coef (`float`, *optional*, defaults to `0.1`):
88
+ Value function coefficient.
89
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
90
+ Clip range for the value function.
91
+ gamma (`float`, *optional*, defaults to `1.0`):
92
+ Discount factor.
93
+ lam (`float`, *optional*, defaults to `0.95`):
94
+ Lambda value for GAE.
95
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
96
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
97
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
98
+ capacity of a single GPU, albeit at the cost of slower generation.
99
+
100
+ """
101
+ vllm_sampling_params: Optional[Any] = field(
102
+ default = None,
103
+ metadata = {'help': 'vLLM SamplingParams'},
104
+ )
105
+ unsloth_num_chunks : Optional[int] = field(
106
+ default = -1,
107
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
108
+ )
109
+
110
+ def __init__(
111
+ self,
112
+ output_dir = None,
113
+ overwrite_output_dir = None,
114
+ do_train = False,
115
+ do_eval = False,
116
+ do_predict = False,
117
+ eval_strategy = 'no',
118
+ prediction_loss_only = False,
119
+ per_device_train_batch_size = 4,
120
+ per_device_eval_batch_size = 4,
121
+ per_gpu_train_batch_size = None,
122
+ per_gpu_eval_batch_size = None,
123
+ gradient_accumulation_steps = 2,
124
+ eval_accumulation_steps = 2,
125
+ eval_delay = 0,
126
+ torch_empty_cache_steps = 250,
127
+ learning_rate = 5e-05,
128
+ weight_decay = 0.01,
129
+ adam_beta1 = 0.9,
130
+ adam_beta2 = 0.999,
131
+ adam_epsilon = 1e-08,
132
+ max_grad_norm = 1.0,
133
+ num_train_epochs = 3.0,
134
+ max_steps = -1,
135
+ lr_scheduler_type = 'linear',
136
+ warmup_ratio = 0.1,
137
+ warmup_steps = 0,
138
+ log_level = 'passive',
139
+ log_level_replica = 'warning',
140
+ log_on_each_node = True,
141
+ logging_dir = None,
142
+ logging_strategy = 'steps',
143
+ logging_first_step = False,
144
+ logging_steps = 1,
145
+ logging_nan_inf_filter = False,
146
+ save_strategy = 'steps',
147
+ save_steps = 500,
148
+ save_total_limit = None,
149
+ save_safetensors = True,
150
+ save_on_each_node = False,
151
+ save_only_model = False,
152
+ restore_callback_states_from_checkpoint = False,
153
+ no_cuda = False,
154
+ use_cpu = False,
155
+ use_mps_device = False,
156
+ seed = 3407,
157
+ data_seed = 3407,
158
+ jit_mode_eval = False,
159
+ use_ipex = False,
160
+ bf16 = False,
161
+ fp16 = False,
162
+ fp16_opt_level = 'O1',
163
+ half_precision_backend = 'auto',
164
+ bf16_full_eval = False,
165
+ fp16_full_eval = False,
166
+ tf32 = None,
167
+ local_rank = -1,
168
+ ddp_backend = None,
169
+ tpu_num_cores = None,
170
+ tpu_metrics_debug = False,
171
+ debug = '',
172
+ dataloader_drop_last = False,
173
+ eval_steps = None,
174
+ dataloader_num_workers = 0,
175
+ dataloader_prefetch_factor = None,
176
+ past_index = -1,
177
+ run_name = None,
178
+ disable_tqdm = None,
179
+ remove_unused_columns = True,
180
+ label_names = None,
181
+ load_best_model_at_end = False,
182
+ metric_for_best_model = None,
183
+ greater_is_better = None,
184
+ ignore_data_skip = False,
185
+ fsdp = '',
186
+ fsdp_min_num_params = 0,
187
+ fsdp_config = None,
188
+ fsdp_transformer_layer_cls_to_wrap = None,
189
+ accelerator_config = None,
190
+ deepspeed = None,
191
+ label_smoothing_factor = 0.0,
192
+ optim = 'adamw_8bit',
193
+ optim_args = None,
194
+ adafactor = False,
195
+ group_by_length = False,
196
+ length_column_name = 'length',
197
+ report_to = None,
198
+ ddp_find_unused_parameters = None,
199
+ ddp_bucket_cap_mb = None,
200
+ ddp_broadcast_buffers = None,
201
+ dataloader_pin_memory = True,
202
+ dataloader_persistent_workers = False,
203
+ skip_memory_metrics = True,
204
+ use_legacy_prediction_loop = False,
205
+ push_to_hub = False,
206
+ resume_from_checkpoint = None,
207
+ hub_model_id = None,
208
+ hub_strategy = 'every_save',
209
+ hub_token = None,
210
+ hub_private_repo = None,
211
+ hub_always_push = False,
212
+ hub_revision = None,
213
+ gradient_checkpointing = False,
214
+ gradient_checkpointing_kwargs = None,
215
+ include_inputs_for_metrics = False,
216
+ eval_do_concat_batches = True,
217
+ fp16_backend = 'auto',
218
+ push_to_hub_model_id = None,
219
+ push_to_hub_organization = None,
220
+ push_to_hub_token = None,
221
+ mp_parameters = '',
222
+ auto_find_batch_size = True,
223
+ full_determinism = False,
224
+ torchdynamo = None,
225
+ ray_scope = 'last',
226
+ ddp_timeout = 1800,
227
+ torch_compile = False,
228
+ torch_compile_backend = None,
229
+ torch_compile_mode = None,
230
+ include_tokens_per_second = False,
231
+ include_num_input_tokens_seen = False,
232
+ neftune_noise_alpha = None,
233
+ optim_target_modules = None,
234
+ batch_eval_metrics = False,
235
+ eval_on_start = False,
236
+ use_liger_kernel = False,
237
+ liger_kernel_config = None,
238
+ eval_use_gather_object = False,
239
+ average_tokens_across_devices = True,
240
+ dataset_num_proc = None,
241
+ num_mini_batches = 1,
242
+ total_episodes = None,
243
+ local_rollout_forward_batch_size = 64,
244
+ num_sample_generations = 10,
245
+ response_length = 53,
246
+ stop_token = None,
247
+ stop_token_id = None,
248
+ temperature = 0.7,
249
+ missing_eos_penalty = None,
250
+ sft_model_path = 'EleutherAI/pythia-160m',
251
+ world_size = None,
252
+ num_total_batches = None,
253
+ micro_batch_size = None,
254
+ local_batch_size = None,
255
+ batch_size = None,
256
+ local_mini_batch_size = None,
257
+ mini_batch_size = None,
258
+ exp_name = 'ppo_config',
259
+ reward_model_path = 'EleutherAI/pythia-160m',
260
+ model_adapter_name = None,
261
+ ref_adapter_name = None,
262
+ num_ppo_epochs = 4,
263
+ whiten_rewards = False,
264
+ kl_coef = 0.05,
265
+ kl_estimator = 'k1',
266
+ cliprange = 0.2,
267
+ vf_coef = 0.1,
268
+ cliprange_value = 0.2,
269
+ gamma = 1.0,
270
+ lam = 0.95,
271
+ ds3_gather_for_generation = True,
272
+ vllm_sampling_params = None,
273
+ unsloth_num_chunks = -1,
274
+
275
+ **kwargs,
276
+ ):
277
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
278
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
279
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
280
+ output_dir = 'unsloth_training_checkpoints'
281
+ save_strategy = 'no'
282
+ if dataset_num_proc is None:
283
+ from multiprocessing import cpu_count
284
+ dataset_num_proc = min(cpu_count()*2, 2)
285
+ if temperature <= 0:
286
+ raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
287
+ elif temperature >= 10:
288
+ raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
289
+
290
+
291
+ super().__init__(
292
+ output_dir = output_dir,
293
+ overwrite_output_dir = overwrite_output_dir,
294
+ do_train = do_train,
295
+ do_eval = do_eval,
296
+ do_predict = do_predict,
297
+ eval_strategy = eval_strategy,
298
+ prediction_loss_only = prediction_loss_only,
299
+ per_device_train_batch_size = per_device_train_batch_size,
300
+ per_device_eval_batch_size = per_device_eval_batch_size,
301
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
302
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
303
+ gradient_accumulation_steps = gradient_accumulation_steps,
304
+ eval_accumulation_steps = eval_accumulation_steps,
305
+ eval_delay = eval_delay,
306
+ torch_empty_cache_steps = torch_empty_cache_steps,
307
+ learning_rate = learning_rate,
308
+ weight_decay = weight_decay,
309
+ adam_beta1 = adam_beta1,
310
+ adam_beta2 = adam_beta2,
311
+ adam_epsilon = adam_epsilon,
312
+ max_grad_norm = max_grad_norm,
313
+ num_train_epochs = num_train_epochs,
314
+ max_steps = max_steps,
315
+ lr_scheduler_type = lr_scheduler_type,
316
+ warmup_ratio = warmup_ratio,
317
+ warmup_steps = warmup_steps,
318
+ log_level = log_level,
319
+ log_level_replica = log_level_replica,
320
+ log_on_each_node = log_on_each_node,
321
+ logging_dir = logging_dir,
322
+ logging_strategy = logging_strategy,
323
+ logging_first_step = logging_first_step,
324
+ logging_steps = logging_steps,
325
+ logging_nan_inf_filter = logging_nan_inf_filter,
326
+ save_strategy = save_strategy,
327
+ save_steps = save_steps,
328
+ save_total_limit = save_total_limit,
329
+ save_safetensors = save_safetensors,
330
+ save_on_each_node = save_on_each_node,
331
+ save_only_model = save_only_model,
332
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
333
+ no_cuda = no_cuda,
334
+ use_cpu = use_cpu,
335
+ use_mps_device = use_mps_device,
336
+ seed = seed,
337
+ data_seed = data_seed,
338
+ jit_mode_eval = jit_mode_eval,
339
+ use_ipex = use_ipex,
340
+ bf16 = bf16,
341
+ fp16 = fp16,
342
+ fp16_opt_level = fp16_opt_level,
343
+ half_precision_backend = half_precision_backend,
344
+ bf16_full_eval = bf16_full_eval,
345
+ fp16_full_eval = fp16_full_eval,
346
+ tf32 = tf32,
347
+ local_rank = local_rank,
348
+ ddp_backend = ddp_backend,
349
+ tpu_num_cores = tpu_num_cores,
350
+ tpu_metrics_debug = tpu_metrics_debug,
351
+ debug = debug,
352
+ dataloader_drop_last = dataloader_drop_last,
353
+ eval_steps = eval_steps,
354
+ dataloader_num_workers = dataloader_num_workers,
355
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
356
+ past_index = past_index,
357
+ run_name = run_name,
358
+ disable_tqdm = disable_tqdm,
359
+ remove_unused_columns = remove_unused_columns,
360
+ label_names = label_names,
361
+ load_best_model_at_end = load_best_model_at_end,
362
+ metric_for_best_model = metric_for_best_model,
363
+ greater_is_better = greater_is_better,
364
+ ignore_data_skip = ignore_data_skip,
365
+ fsdp = fsdp,
366
+ fsdp_min_num_params = fsdp_min_num_params,
367
+ fsdp_config = fsdp_config,
368
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
369
+ accelerator_config = accelerator_config,
370
+ deepspeed = deepspeed,
371
+ label_smoothing_factor = label_smoothing_factor,
372
+ optim = optim,
373
+ optim_args = optim_args,
374
+ adafactor = adafactor,
375
+ group_by_length = group_by_length,
376
+ length_column_name = length_column_name,
377
+ report_to = report_to,
378
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
379
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
380
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
381
+ dataloader_pin_memory = dataloader_pin_memory,
382
+ dataloader_persistent_workers = dataloader_persistent_workers,
383
+ skip_memory_metrics = skip_memory_metrics,
384
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
385
+ push_to_hub = push_to_hub,
386
+ resume_from_checkpoint = resume_from_checkpoint,
387
+ hub_model_id = hub_model_id,
388
+ hub_strategy = hub_strategy,
389
+ hub_token = hub_token,
390
+ hub_private_repo = hub_private_repo,
391
+ hub_always_push = hub_always_push,
392
+ hub_revision = hub_revision,
393
+ gradient_checkpointing = gradient_checkpointing,
394
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
395
+ include_inputs_for_metrics = include_inputs_for_metrics,
396
+ eval_do_concat_batches = eval_do_concat_batches,
397
+ fp16_backend = fp16_backend,
398
+ push_to_hub_model_id = push_to_hub_model_id,
399
+ push_to_hub_organization = push_to_hub_organization,
400
+ push_to_hub_token = push_to_hub_token,
401
+ mp_parameters = mp_parameters,
402
+ auto_find_batch_size = auto_find_batch_size,
403
+ full_determinism = full_determinism,
404
+ torchdynamo = torchdynamo,
405
+ ray_scope = ray_scope,
406
+ ddp_timeout = ddp_timeout,
407
+ torch_compile = torch_compile,
408
+ torch_compile_backend = torch_compile_backend,
409
+ torch_compile_mode = torch_compile_mode,
410
+ include_tokens_per_second = include_tokens_per_second,
411
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
412
+ neftune_noise_alpha = neftune_noise_alpha,
413
+ optim_target_modules = optim_target_modules,
414
+ batch_eval_metrics = batch_eval_metrics,
415
+ eval_on_start = eval_on_start,
416
+ use_liger_kernel = use_liger_kernel,
417
+ liger_kernel_config = liger_kernel_config,
418
+ eval_use_gather_object = eval_use_gather_object,
419
+ average_tokens_across_devices = average_tokens_across_devices,
420
+ dataset_num_proc = dataset_num_proc,
421
+ num_mini_batches = num_mini_batches,
422
+ total_episodes = total_episodes,
423
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
424
+ num_sample_generations = num_sample_generations,
425
+ response_length = response_length,
426
+ stop_token = stop_token,
427
+ stop_token_id = stop_token_id,
428
+ temperature = temperature,
429
+ missing_eos_penalty = missing_eos_penalty,
430
+ sft_model_path = sft_model_path,
431
+ world_size = world_size,
432
+ num_total_batches = num_total_batches,
433
+ micro_batch_size = micro_batch_size,
434
+ local_batch_size = local_batch_size,
435
+ batch_size = batch_size,
436
+ local_mini_batch_size = local_mini_batch_size,
437
+ mini_batch_size = mini_batch_size,
438
+ exp_name = exp_name,
439
+ reward_model_path = reward_model_path,
440
+ model_adapter_name = model_adapter_name,
441
+ ref_adapter_name = ref_adapter_name,
442
+ num_ppo_epochs = num_ppo_epochs,
443
+ whiten_rewards = whiten_rewards,
444
+ kl_coef = kl_coef,
445
+ kl_estimator = kl_estimator,
446
+ cliprange = cliprange,
447
+ vf_coef = vf_coef,
448
+ cliprange_value = cliprange_value,
449
+ gamma = gamma,
450
+ lam = lam,
451
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
452
+ self.vllm_sampling_params = vllm_sampling_params
453
+ self.unsloth_num_chunks = unsloth_num_chunks
454
+
455
+ pass
456
+
457
+ class _UnslothPPOTrainer(Trainer):
458
+ _tag_names = ["trl", "ppo"]
459
+
460
def __init__(
    self,
    args: PPOConfig,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ],
    model: nn.Module,
    ref_model: Optional[nn.Module],
    reward_model: nn.Module,
    train_dataset: Dataset,
    value_model: nn.Module,
    data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    # less commonly used
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    callbacks: Optional[list[TrainerCallback]] = None,
    peft_config: Optional["PeftConfig"] = None,
) -> None:
    """Set up models, batch-size bookkeeping, optimizer, callbacks and dataloaders for PPO.

    Args:
        args: PPO configuration. Several fields are overwritten here:
            `total_episodes` (only when it is None), `world_size`,
            `local_batch_size`, `micro_batch_size`, `batch_size`,
            `mini_batch_size`, `local_mini_batch_size`, `num_total_batches`
            and `run_name`.
        processing_class: Used to build the default `DataCollatorWithPadding`
            and to look up `eos_token_id` when `args.stop_token == "eos"`.
        model: Policy model. It is combined with `value_model` in a
            `PolicyAndValueWrapper` that becomes `self.model`.
        ref_model: Reference model. Must not be the same object as `model`.
            When falsy, it stays None for a PEFT policy; otherwise one is
            created with `create_reference_model`.
        reward_model: Model stored as `self.reward_model`.
        train_dataset: Training dataset; `len()` is taken of it.
        value_model: Value model wrapped together with the policy.
        data_collator: Collator for both dataloaders. Defaults to
            `DataCollatorWithPadding(processing_class)`.
        eval_dataset: Dataset handed to the evaluation `DataLoader`.
        optimizers: `(optimizer, lr_scheduler)` pair unpacked into
            `self.optimizer` and `self.lr_scheduler`.
        callbacks: Extra callbacks appended to the default ones.
        peft_config: When given (and PEFT is available) the policy is wrapped
            with `get_peft_model`, after `merge_and_unload` if it is already
            a `PeftModel`.

    Raises:
        ValueError: If `ref_model is model`, if both `stop_token` and
            `stop_token_id` are set, if `stop_token` is not `"eos"`, if
            `kl_estimator` is not `"k1"`/`"k3"`, or if there is no reference
            model and the policy is not a PEFT model.
        ImportError: If `peft_config` is passed but PEFT is not installed.
        AssertionError: If `whiten_rewards` is set and the per-rank
            mini-batch size is below 8.
    """
    if ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you must make a copy of it, or `None` if you use peft."
        )

    self.args = args
    self.processing_class = processing_class
    self.policy_model = model

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DataCollatorWithPadding(self.processing_class)

    # Handle stop token settings: update policy model's generation_config to use provided stop token
    # NOTE(review): both checks are truthiness tests, so `stop_token_id == 0` counts as "not set" here.
    if args.stop_token and args.stop_token_id:
        raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
    elif args.stop_token:
        if args.stop_token == "eos":
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
        else:
            raise ValueError(
                f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
            )
    else:
        self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

    # Check that the kl estimator is valid
    if self.args.kl_estimator not in {"k1", "k3"}:
        raise ValueError(
            "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
            "appears to be a strictly better estimator). See "
            "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
        )

    # peft support
    if not is_peft_available() and peft_config is not None:
        raise ImportError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(self.policy_model, PeftModel):
            self.policy_model = self.policy_model.merge_and_unload()

        # get peft model with the given config
        self.policy_model = get_peft_model(self.policy_model, peft_config)
        if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(self.policy_model)

    self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
    self.model_adapter_name = args.model_adapter_name
    self.ref_adapter_name = args.ref_adapter_name

    # Reference model: explicit one wins; a PEFT policy keeps None; otherwise derive one from the policy.
    if ref_model:
        self.ref_model = ref_model
    elif self.is_peft_model:
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(self.policy_model)

    self.reward_model = reward_model
    self.train_dataset = train_dataset
    self.train_dataset_len = len(train_dataset)
    self.value_model = value_model
    self.data_collator = data_collator
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

    #########
    # calculate various batch sizes
    #########
    if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
        args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.accelerator = accelerator
    args.world_size = accelerator.num_processes
    args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    args.batch_size = int(args.local_batch_size * args.world_size)
    args.mini_batch_size = exact_div(
        args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
    )
    args.local_mini_batch_size = exact_div(
        args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
    )
    if args.whiten_rewards:
        assert args.local_mini_batch_size >= 8, (
            f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
        )
    # `per_rank_rollout_batch_size` is our `args.local_batch_size`
    # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
    args.num_total_batches = math.ceil(
        args.total_episodes / args.batch_size
    )  # we may train for more than `total_episodes`
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
    # `sample_generations_freq` is only defined when sample generation is enabled.
    if args.num_sample_generations > 0:
        self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
    self.local_dataloader_batch_size = args.local_batch_size

    #########
    # setup model, optimizer, and others
    #########
    for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
        if module is not None:
            disable_dropout_in_model(module)
    self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
    self.model.config = self.policy_model.config  # needed for pushing to hub
    self.create_optimizer_and_scheduler(
        num_training_steps=args.num_total_batches
    )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = OnlineTrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.current_flos = 0
    self.hp_search_backend = None
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=self.local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,  # needed; otherwise the last batch will be of ragged shape
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
    torch.manual_seed(self.local_seed)  # reset the local seed again

    # NOTE(review): built unconditionally, even though `eval_dataset` defaults to None — confirm callers pass one.
    self.eval_dataloader = DataLoader(
        self.eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=True,
    )  # no need to shuffle eval dataset
    self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

    # The reward and reference models are not passed through `accelerator.prepare`;
    # they are either handed to `prepare_deepspeed` or moved to the accelerator device.
    if self.is_deepspeed_enabled:
        self.reward_model = prepare_deepspeed(
            self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
        )

        if self.ref_model is None:
            if not self.is_peft_model:
                raise ValueError("No reference model and model is not a Peft model.")
        else:
            self.ref_model = prepare_deepspeed(
                self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
    else:
        if self.ref_model is None:
            if not self.is_peft_model:
                raise ValueError("No reference model and model is not a Peft model.")
        else:
            self.ref_model = self.ref_model.to(self.accelerator.device)
        self.reward_model = self.reward_model.to(self.accelerator.device)
669
+
670
def get_train_dataloader(self) -> DataLoader:
    """Return the training dataloader held in `self.dataloader`."""
    train_loader = self.dataloader
    return train_loader
672
+
673
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader held in `self.eval_dataloader`."""
    eval_loader = self.eval_dataloader
    return eval_loader
675
+
676
@contextmanager
def null_ref_context(self):
    """Context manager for handling a null reference model (that is, peft adapter manipulation).

    Inside the context the policy acts as the reference model:

    * if `self.ref_adapter_name` is set, that adapter is activated on
      `self.model.policy` for the duration of the context;
    * otherwise, for a PEFT policy, the adapters of the unwrapped policy are
      disabled through `disable_adapter()`.

    On exit the training adapter (`self.model_adapter_name`, falling back to
    `"default"`) is re-activated. The restore runs in a `finally` clause so
    that an exception raised inside the `with` body cannot leave the policy
    stuck on the reference adapter.
    """
    with (
        self.accelerator.unwrap_model(self.model.policy).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        if self.ref_adapter_name:
            self.model.policy.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Always switch back, even when the caller's block raised.
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")
689
+
690
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Save only the policy part of the wrapped policy/value model.

    `self.model` (and, under DeepSpeed, `self.deepspeed`) is temporarily
    pointed at `self.model.policy` so that the parent `save_model` writes the
    policy alone. The original objects are restored afterwards, including when
    the parent call raises, so a failed save cannot leave the trainer holding
    the bare policy instead of the policy/value wrapper.

    Args:
        output_dir: Forwarded unchanged to the parent `save_model`.
        _internal_call: Forwarded unchanged to the parent `save_model`.
    """
    backup_model = self.model
    swap_deepspeed = self.is_deepspeed_enabled
    # Read the DeepSpeed handle before mutating anything so a failure here leaves state untouched.
    backup_deepspeed = self.deepspeed if swap_deepspeed else None

    self.model = backup_model.policy  # save only the policy
    if swap_deepspeed:
        self.deepspeed = self.model

    try:
        super().save_model(output_dir, _internal_call)
    finally:
        self.model = backup_model
        if swap_deepspeed:
            self.deepspeed = backup_deepspeed
704
+
705
def train(self):
    """Run the PPO training loop for `args.num_total_batches` updates.

    Each update:

    1. draws one batch of queries and, under `torch.no_grad()`, generates
       responses with the policy, then computes policy and reference
       log-probs, value estimates and reward-model scores in chunks of
       `args.local_rollout_forward_batch_size`;
    2. builds per-token rewards from the KL penalty (`k1` or `k3` estimator)
       plus the score added at the end of each response, optionally whitens
       them, and computes GAE advantages and returns;
    3. runs `args.num_ppo_epochs` epochs of clipped PPO policy/value updates
       over shuffled mini-batches and micro-batches;
    4. gathers and logs metrics, steps the LR scheduler, fires the
       `on_step_end` callbacks and saves a checkpoint when requested;
    5. optionally calls `generate_completions`, then frees the rollout
       tensors.

    After the loop, `on_train_end` is fired and a final checkpoint is saved if
    the callbacks request it.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    ref_policy = self.ref_model
    reward_model = self.reward_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        # Cycle over the dataloader forever; the update loop decides when to stop.
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    # One slot per (ppo epoch, mini-batch, gradient-accumulation step); overwritten every update.
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    vf_loss_stats = torch.zeros(stats_shape, device=device)
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = args.num_total_batches
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model
        self.model_wrapped = self.model

    for update in range(1, args.num_total_batches + 1):
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        with torch.no_grad():
            queries = data["input_ids"].to(device)
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []
            values = []
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model.policy,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            # Score the rollout chunk by chunk to bound peak memory.
            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                logprob = selective_log_softmax(logits, response)
                del logits
                empty_cache()

                # Without a separate reference model, reuse the policy inside `null_ref_context`.
                if ref_policy is None:
                    with self.null_ref_context():
                        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                else:
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                unwrapped_value_model = accelerator.unwrap_model(model).value_model
                full_value, _, _ = get_reward(
                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                )
                value = full_value[:, context_length - 1 : -1].squeeze(-1)
                _, score, _ = get_reward(
                    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )

                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)
                values.append(value)
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            values = torch.cat(values, 0)
            del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
            empty_cache()
            gc.collect()

            # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
            sequence_lengths_p1 = sequence_lengths + 1
            padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
            values = torch.masked_fill(values, padding_mask_p1, 0)

            # 4. compute rewards
            # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
            logr = ref_logprobs - logprobs
            kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
            non_score_reward = -args.kl_coef * kl
            rewards = non_score_reward.clone()
            actual_start = torch.arange(rewards.size(0), device=rewards.device)
            actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
            # Add each sequence's score at its (row, end-position) cell. Index with a tuple of
            # index tensors: a *list* of tensors is the deprecated non-tuple form, which newer
            # PyTorch treats as a single tensor index.
            rewards[actual_start, actual_end] += scores

            # 5. whiten rewards
            if args.whiten_rewards:
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

            # 6. compute advantages and returns
            # Generalized advantage estimation, accumulated backwards over response positions.
            lastgaelam = 0
            advantages_reversed = []
            gen_length = responses.shape[1]
            for t in reversed(range(gen_length)):
                nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                lastgaelam = delta + args.gamma * args.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], axis=1)
            returns = advantages + values
            advantages = masked_whiten(advantages, ~padding_mask)
            advantages = torch.masked_fill(advantages, padding_mask, 0)
            empty_cache()

        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]
                        mb_return = returns[micro_batch_inds]
                        mb_values = values[micro_batch_inds]

                        output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )
                        vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                        vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                        # Clipped value loss: take the worse of the clipped and unclipped errors.
                        vpredclipped = torch.clamp(
                            vpred,
                            mb_values - args.cliprange_value,
                            mb_values + args.cliprange_value,
                        )
                        vf_losses1 = torch.square(vpred - mb_return)
                        vf_losses2 = torch.square(vpredclipped - mb_return)
                        vf_loss_max = torch.max(vf_losses1, vf_losses2)
                        vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                        vf_clipfrac = masked_mean(
                            (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                        )
                        # Clipped policy-gradient loss on the probability ratio new/old.
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                        loss = pg_loss + args.vf_coef * vf_loss
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()
                        with torch.no_grad():
                            pg_clipfrac = masked_mean(
                                (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                            )
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype = torch.float32).to(logits.dtype)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                            vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                vf_clipfrac
                            )
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1
                # del everything and empty cache
                # fmt: off
                del (
                    output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                    vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                    pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                    mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                empty_cache()
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.sum(1).mean()
            rlhf_reward = mean_non_score_reward + scores.mean()
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
            self.state.global_step += 1
            self.log(metrics)

        # The scheduler is stepped once per rollout batch, not per optimizer step.
        self.lr_scheduler.step()
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
        empty_cache()
        gc.collect()

        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)
            empty_cache()
        del (
            query_responses,
            responses,
            postprocessed_responses,
            logprobs,
            ref_logprobs,
            values,
            sequence_lengths,
            contain_eos_token,
            sequence_lengths_p1,
            response_idxs,
            padding_mask,
            padding_mask_p1,
            rewards,
            actual_start,
            actual_end,
            advantages,
            returns,
        )
        empty_cache()

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        self._save_checkpoint(model, trial=None, metrics=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1044
+
1045
def generate_completions(self, sampling: bool = False):
    """
    Generate completions for the evaluation prompts, score them and report the results.

    For each batch of `self.eval_dataloader` the policy generates a response, the response is truncated at
    `self.stop_token_id` when one is configured, and the reward model scores the query + (post-processed)
    response. Decoded queries, decoded responses and scores are gathered across processes into a table. The
    main process then prints the first five rows (when `rich` is available) and logs the table to Weights &
    Biases and/or Comet according to `args.report_to`.

    Args:
        sampling (`bool`, *optional*, defaults to `False`):
            When `True`, stop after the first evaluation batch.
    """
    args = self.args
    processing_class = self.processing_class

    # Sampling is enabled but with a very low temperature and no top-k / top-p filtering.
    generation_config = GenerationConfig(
        max_new_tokens=self.args.response_length,
        temperature=(0.01 + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    rows = defaultdict(list)
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        for batch in self.eval_dataloader:
            query = batch["input_ids"]
            with torch.no_grad():
                context_length = query.shape[1]
                query_response, _ = batch_generation(
                    unwrapped_model.policy,
                    query,
                    query.shape[0],
                    processing_class.pad_token_id,
                    generation_config,
                )
                # Keep only the generated part (everything after the prompt tokens).
                response = query_response[:, context_length:]

                # `is not None` (rather than truthiness) so that a stop token id of 0 is still honoured.
                if self.stop_token_id is not None:
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )
                else:
                    postprocessed_response = response

                decoded_queries = processing_class.batch_decode(query, skip_special_tokens=True)
                rows["query"].extend(gather_object(decoded_queries))
                decoded_responses = processing_class.batch_decode(postprocessed_response)
                rows["model response"].extend(gather_object(decoded_responses))

                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                _, score, _ = get_reward(
                    self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )
                gathered_scores = self.accelerator.gather_for_metrics(score)
                rows["score"].extend(gathered_scores.float().cpu().numpy())

            if sampling:
                break
    df = pd.DataFrame(rows)

    # Only the main process prints / logs the table.
    if not self.accelerator.is_main_process:
        return

    if is_rich_available():
        print_rich_table(df.iloc[0:5])

    if "wandb" in args.report_to:
        import wandb

        if wandb.run is not None:
            wandb.log({"completions": wandb.Table(dataframe=df)})

    if "comet_ml" in args.report_to:
        log_table_to_comet_experiment(
            name="completions.csv",
            table=df,
        )
1108
+
1109
+ # Ensure the model card is saved along with the checkpoint
1110
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Write the model card, then delegate the actual checkpointing to the parent `Trainer`.

    Args:
        model: Model to checkpoint; forwarded unchanged to the parent implementation.
        trial: Hyper-parameter search trial (or `None`); forwarded unchanged to the parent implementation.
        metrics: Accepted only for call-site compatibility and otherwise ignored. The end-of-training
            path in this trainer calls `self._save_checkpoint(model, trial=None, metrics=None)`; without
            this parameter that call raises `TypeError`.
    """
    # The model card is named after the hub repo when one is configured, otherwise after the
    # output directory.
    hub_model_id = self.args.hub_model_id
    if hub_model_id is None:
        model_name = Path(self.args.output_dir).name
    else:
        model_name = hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    # `metrics` is deliberately not forwarded: the parent is called with the same two
    # arguments as before.
    super()._save_checkpoint(model, trial)
1117
+
1118
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Build a draft model card from the trainer's state and save it as `README.md` in `args.output_dir`.

    Only the world-process-zero process writes the card; every other process returns immediately.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that is an existing local directory is not reported as the base model.
    model_config = self.model.config
    base_model = None
    if hasattr(model_config, "_name_or_path") and not os.path.isdir(model_config._name_or_path):
        base_model = model_config._name_or_path

    # Normalise `tags` into a fresh set so the caller's object is never mutated.
    if isinstance(tags, str):
        tags = {tags}
    elif tags is None:
        tags = set()
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    citation = textwrap.dedent("""\
    @article{mziegler2019fine-tuning,
        title = {{Fine-Tuning Language Models from Human Preferences}},
        author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
        year = 2019,
        eprint = {arXiv:1909.08593}
    }""")

    # Link the active Weights & Biases run, if there is one.
    if is_wandb_available() and wandb.run is not None:
        wandb_url = wandb.run.url
    else:
        wandb_url = None

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="PPO",
        trainer_citation=citation,
        paper_title="Fine-Tuning Language Models from Human Preferences",
        paper_id="1909.08593",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    model_card.save(readme_path)
1179
class UnslothPPOTrainer(_UnslothPPOTrainer):
    """
    Unsloth wrapper around `_UnslothPPOTrainer`.

    The constructor reconciles the precision flags on `args` with the model dtype, adjusts evaluation
    settings, forces right padding on the tokenizer / processor, swaps in a data collator that matches
    the dataset columns, and then delegates to the parent constructor.
    """
    def __init__(
        self,
        args,
        processing_class,
        model,
        ref_model,
        reward_model,
        train_dataset,
        value_model,
        data_collator = None,
        eval_dataset = None,
        callbacks = None,
        peft_config = None,
        **kwargs
    ):
        if args is None:
            args = UnslothPPOConfig()

        # ---- Mixed precision: make args.fp16 / args.bf16 agree with the model dtype ----
        use_bf16 = getattr(args, 'bf16', False)
        if not isinstance(use_bf16, bool):
            use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if not isinstance(use_fp16, bool):
            use_fp16 = False

        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32:
            if float16 and use_bf16:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not float16 and use_fp16:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ----
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # ---- Full-eval precision flags ----
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if not isinstance(fp16_full_eval, bool):
            fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if not isinstance(bf16_full_eval, bool):
            bf16_full_eval = False

        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False

        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif mixed_precision_dtype == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # ---- Ask for logits only when a metric hook is present in this scope ----
        _output_logits = False
        if locals().get('compute_metrics', None) is not None:
            _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None:
            _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # ---- Inherit max_seq_length from the model when args leaves it unset ----
        if 'max_seq_length' in locals() or hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'):
                    args.max_seq_length = max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Force right padding on whichever tokenizer-like object is available ----
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'):
            tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'):
                processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
                processing_class.tokenizer.padding_side = 'right'
        _tokenizer = processing_class if 'processing_class' in locals() else tokenizer

        # ---- Pick a data collator consistent with the dataset columns ----
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            has_labels = 'labels' in train_dataset.column_names
            if isinstance(data_collator, DataCollatorForSeq2Seq) and not has_labels:
                data_collator = TransformersDataCollatorForLanguageModeling(_tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and has_labels:
                data_collator = DataCollatorForSeq2Seq(_tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}

        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The object has no `pad` but wraps a `.tokenizer`: rebuild the collator on the inner tokenizer.
            if not hasattr(_tokenizer, 'pad') and hasattr(_tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(_tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(_tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ppo_trainer', other_metrics)

        super().__init__(
            args = args,
            processing_class = processing_class,
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            train_dataset = train_dataset,
            value_model = value_model,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            callbacks = callbacks,
            peft_config = peft_config,
            **kwargs
        )

        # ---- NEFTune: drop the hook handle and store the noise alpha on the embedding module ----
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        # ---- Attach the accelerator's scaler to every level of the `.model` chain ----
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
unsloth_compiled_cache/UnslothPRMTrainer.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options passed as `options=` to `torch.compile` by the compiled helpers in this module.
torch_compile_options = dict((
    ("epilogue_fusion",   True),
    ("max_autotune",      False),
    ("shape_padding",     True),
    ("trace.enabled",     False),
    ("triton.cudagraphs", False),
))
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` evaluated only at the positions given by `index`.

    The leading dimensions of `logits` are flattened into rows and processed in (at most) four chunks.
    Each chunk is upcast to float32, and the result for a row is
    `logits[row, index[row]] - logsumexp(logits[row])`. The per-row values are finally reshaped to
    `(logits.shape[0], logits.shape[1])`.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)
    logit_chunks = torch.chunk(flat_logits, chunks = 4, dim = 0)
    index_chunks = torch.chunk(flat_index, chunks = 4, dim = 0)

    pieces = []
    for rows, targets in zip(logit_chunks, index_chunks):
        rows = rows.to(torch.float32)
        # Logit of the selected class for every row in this chunk.
        picked = torch.gather(rows, dim = -1, index = targets.unsqueeze(-1)).squeeze(-1)
        # log_softmax(x)[i] == x[i] - logsumexp(x)
        pieces.append(picked - torch.logsumexp(rows, dim = -1))

    merged = torch.concat(pieces)
    return merged.reshape((logits.shape[0], logits.shape[1]))
51
@dataclass
class UnslothPRMConfig(PRMConfig):
    """
    Configuration class for the [`PRMTrainer`], extended with Unsloth-specific fields.

    This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) used for truncation.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt used for truncation.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        step_separator (`str`, *optional*, defaults to `"\\n"`):
            Separator used to separate each step of the reasoning process.
        train_on_last_step_only (`bool`, *optional*, defaults to `False`):
            Whether to train only on the last step.
        dataset_num_proc (`int`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, the constructor replaces it
            with `min(cpu_count() * 2, 2)`.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient.
        max_seq_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum sequence length to truncate to.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = True,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = True,
        max_length = 1024,
        max_prompt_length = 512,
        max_completion_length = None,
        disable_dropout = True,
        # Must be the escape sequence '\n'. The previous spelling (a backslash followed by a
        # physical line break inside the quotes) is a line continuation and evaluates to ''.
        step_separator = '\n',
        train_on_last_step_only = False,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        max_seq_length = None,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before anything else is configured.
        if learning_rate < 1e-7:
            raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1:
            raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        # With every save-related argument left at its default, use a fixed output directory
        # and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # Default number of dataset workers; the expression is capped at 2.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = min(cpu_count()*2, 2)

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            disable_dropout = disable_dropout,
            step_separator = step_separator,
            train_on_last_step_only = train_on_last_step_only,
            dataset_num_proc = dataset_num_proc,
            **kwargs,
        )
        # Unsloth-only settings are stored directly on the instance rather than forwarded to the parent.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
386
+
387
class _UnslothPRMTrainer(Trainer):
    """
    Trainer for process reward models (PRM): step-wise token classification.

    On construction it optionally prepares a quantized model for k-bit training,
    fills in a default data collator (`DataCollatorForTokenClassification`) and a
    default metric (`compute_accuracy`), and tokenizes the datasets with
    `tokenize_row` when the train dataset has no `"input_ids"` column. A model card
    is written to the output directory before every checkpoint save.
    """

    _tag_names = ["trl", "prm"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[PRMConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
    ):
        """
        Prepare the model and datasets, then delegate to `Trainer.__init__`.

        Args:
            model: Model to train.
            args: Training configuration. It is dereferenced unconditionally
                (`args.disable_dropout`, `args.max_length`, ...), so it must not be `None` here.
            data_collator: Collator to use; defaults to `DataCollatorForTokenClassification`
                built from `processing_class` and `args.max_length`.
            train_dataset: Training dataset. Its `column_names` are read unconditionally.
            eval_dataset: Evaluation dataset, or a mapping of split name to dataset.
            processing_class: Tokenizer/processor; required when `data_collator` is `None`.
            model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics:
                Forwarded to `Trainer.__init__` (`compute_metrics` defaults to `compute_accuracy`).
            peft_config: Only used to decide whether a quantized, non-PEFT model is passed through
                `prepare_model_for_kbit_training`; this class does not wrap the model in an adapter.

        Raises:
            ValueError: If `peft_config` is given but PEFT is not installed, or if both
                `data_collator` and `processing_class` are `None`.
        """
        if peft_config is not None:
            if not is_peft_available():
                raise ValueError(
                    "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
                )

            is_quantized = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False)
            if not isinstance(model, PeftModel) and is_quantized:
                # Older peft releases do not accept `gradient_checkpointing_kwargs`.
                _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )
                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if args.gradient_checkpointing_kwargs is not None:
                    if _supports_gc_kwargs:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
                    else:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if processing_class is None:
                raise ValueError(
                    "A processing_class must be specified when using the default DataCollatorForTokenClassification"
                )
            data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)

        if "input_ids" not in train_dataset.column_names:
            with PartialState().main_process_first():
                fn_kwargs = {
                    "tokenizer": processing_class,
                    "step_separator": args.step_separator,
                    "max_length": args.max_length,
                    "max_prompt_length": args.max_prompt_length,
                    "max_completion_length": args.max_completion_length,
                    "train_on_last_step_only": args.train_on_last_step_only,
                }

                def _tokenize(dataset, is_eval, desc):
                    # Replace every original column with the tokenized `labels` / `input_ids`.
                    return dataset.map(
                        self.tokenize_row,
                        fn_kwargs={**fn_kwargs, "is_eval": is_eval},
                        num_proc=args.dataset_num_proc,
                        remove_columns=dataset.features,
                        desc=desc,
                        features=features.Features(  # needed to avoid map to cast labels to bool
                            {
                                "labels": features.Sequence(features.Value("int64")),
                                "input_ids": features.Sequence(features.Value("int64")),
                            }
                        ),
                    )

                train_dataset = _tokenize(train_dataset, False, "Tokenizing train dataset")

                if isinstance(eval_dataset, dict):
                    # The signature allows a mapping of named eval splits; tokenize each one.
                    eval_dataset = {
                        split: _tokenize(split_dataset, True, "Tokenizing eval dataset")
                        for split, split_dataset in eval_dataset.items()
                    }
                elif eval_dataset is not None:
                    eval_dataset = _tokenize(eval_dataset, True, "Tokenizing eval dataset")

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    @staticmethod
    def tokenize_row(
        features,
        tokenizer,
        step_separator,
        max_length,
        max_prompt_length,
        max_completion_length,
        train_on_last_step_only,
        is_eval,
    ):
        r"""
        Tokenize a row of the dataset.

        Args:
            features (`dict[str, str]`):
                Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
            tokenizer (`PreTrainedTokenizerBase`):
                Tokenizer used to process the data.
            step_separator (`str`):
                Separator between steps in the completion.
            max_length (`int` or `None`):
                Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
            max_prompt_length (`int` or `None`):
                Maximum length of the prompt. If `None`, the prompt is not truncated. Truncation keeps the
                last `max_prompt_length` tokens.
            max_completion_length (`int` or `None`):
                Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
            train_on_last_step_only (`bool`):
                Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
                token of the completion.
            is_eval (`bool`):
                Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if
                `train_on_last_step_only` is set to `True`.

        Returns:
            `dict[str, list[int]]`:
                Tokenized sequences with the keys `"input_ids"`, and `"labels"`. Both lists have the same length;
                prompt positions are labelled `-100`.

        Example:
        ```python
        >>> from transformers import AutoTokenizer
        >>> from trl import PRMTrainer

        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        >>> features = {
        ...     "prompt": "Which number is larger, 9.8 or 9.11?",
        ...     "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
        ...     "labels": [True, False],
        ... }
        >>> PRMTrainer.tokenize_row(
        ...     features,
        ...     tokenizer,
        ...     "\n",
        ...     max_length=None,
        ...     max_prompt_length=None,
        ...     max_completion_length=None,
        ...     train_on_last_step_only=False,
        ...     is_eval=False,
        ... )
        {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
         'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
        ```
        """
        # Tokenize the prompt and completions
        prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
        completions_ids = [
            tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
        ]

        # One label per step; when training on the last step only, mask every other step.
        if train_on_last_step_only and not is_eval:
            labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
        else:
            labels = [int(label) for label in features["labels"]]

        # Get the ID of the separator token and add it to the completions
        separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
        completions_ids = [completion + separator_ids for completion in completions_ids]

        # The step label sits on the last token of each step (the separator); the rest is masked.
        labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]

        # Join the completions and labels steps
        completion_ids = list(chain(*completions_ids))
        labels = list(chain(*labels))

        if tokenizer.bos_token_id is not None:
            prompt_ids = [tokenizer.bos_token_id] + prompt_ids

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            prompt_ids = prompt_ids[-max_prompt_length:]
        if max_completion_length is not None:
            completion_ids = completion_ids[:max_completion_length]
            labels = labels[:max_completion_length]

        input_ids = prompt_ids + completion_ids
        labels = [-100] * len(prompt_ids) + labels

        if max_length is not None:
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]

        return {"input_ids": input_ids, "labels": labels}

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        """Write the model card, then save the checkpoint through the parent class."""
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Only the world-process-zero rank writes the card; it is saved as `README.md` in
        `self.args.output_dir`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        # A local directory path is not reported as the base model.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        citation = textwrap.dedent("""\
        @article{uesato2022solving,
            title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
            author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
            year = 2022,
            journal = {arXiv preprint arXiv:2211.14275}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
            trainer_name="PRM",
            trainer_citation=citation,
            paper_title="Solving math word problems with process-and outcome-based feedback",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
673
class UnslothPRMTrainer(_UnslothPRMTrainer):
    """

    Initialize PRMTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForTokenClassification`.
        args (`PRMConfig`):
            The arguments to use for training. Defaults to `UnslothPRMConfig()` when `None`.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`DataCollatorForTokenClassification`) will be used which will pad the sequences to the maximum length of
            the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
            The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
            will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training. Accepted through `**kwargs`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration, forwarded unchanged to the parent initializer.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        **kwargs
    ):
        """
        Reconcile precision, evaluation and collator settings, then build the trainer.

        Mutates `args` in place (fp16/bf16 flags, eval batch size and accumulation, full-eval
        flags, `max_seq_length`) and sets the `ACCELERATE_MIXED_PRECISION` and
        `UNSLOTH_RETURN_LOGITS` environment variables as a side effect.

        Raises:
            TypeError: If the requested half-precision mode does not match the model dtype
                (unless `UNSLOTH_FORCE_FLOAT32=1`).
        """
        if args is None: args = UnslothPRMConfig()

        # ---- Mixed precision: reconcile the requested flags with the model's dtype ----
        # Anything that is not a real bool is treated as "not requested".
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False

        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing requested explicitly: follow the model's own dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ----
        # NOTE(review): this tests an `eval_dataset` attribute on `args`, not the `eval_dataset`
        # argument of this method. It looks like the argument was intended -- confirm before
        # changing, since doing so would switch on periodic evaluation for existing callers.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # Only shrink the eval batch size when it still has the value 8.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps

        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        # Keep the full-eval precision consistent with the training precision.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif mixed_precision_dtype == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metrics are computed from logits, so request them whenever a metric hook is supplied.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the model's max_seq_length when the config has the field but left it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if args.max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # Force right padding on the tokenizer, or on the tokenizer wrapped by a processor.
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'

        # ---- Data collator reconciliation ----
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap between the seq2seq and LM collators depending on whether `labels` exists.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(processing_class, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(processing_class)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}

        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` of its own but wraps a tokenizer: rebuild the
            # collator around the wrapped tokenizer.
            if not hasattr(processing_class, 'pad') and hasattr(processing_class, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(processing_class.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(processing_class.tokenizer, mlm = False, mlm_probability = 0.0)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('prm_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,**kwargs)

        # Drop the NEFTune hook if one was registered during parent initialisation, and store
        # the noise alpha on the input embeddings instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        # Attach the accelerator's grad scaler to every level of the nested `.model` chain.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
unsloth_compiled_cache/UnslothRLOOTrainer.py ADDED
@@ -0,0 +1,1174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, Path, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_rich_available, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb, Optional, Trainer, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
+ torch_compile_options = {
27
+ "epilogue_fusion" : True,
28
+ "max_autotune" : False,
29
+ "shape_padding" : True,
30
+ "trace.enabled" : False,
31
+ "triton.cudagraphs" : False,
32
+ }
33
+
34
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
+ def chunked_selective_log_softmax(logits, index):
36
+ # Split into 4 chunks only
37
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
+ all_per_token_logps = []
40
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
+ chunk_logits = chunk_logits.to(torch.float32)
43
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
+ per_token_logps = selected_logits - logsumexp_values
46
+ all_per_token_logps.append(per_token_logps)
47
+ pass
48
+ all_per_token_logps = torch.concat(all_per_token_logps)
49
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
+ return all_per_token_logps
51
+ @dataclass
52
+ class UnslothRLOOConfig(RLOOConfig):
53
+ """
54
+
55
+ Configuration class for the [`RLOOTrainer`].
56
+
57
+ This class includes only the parameters that are specific to RLOO training. For a full list of training arguments,
58
+ please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
59
+ values in this class may differ from those in [`~transformers.TrainingArguments`].
60
+
61
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
62
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
63
+ command line.
64
+
65
+ Parameters:
66
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
67
+ Name of this experiment.
68
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
69
+ Path to the reward model.
70
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
71
+ Number of epochs to train.
72
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
73
+ Whether to whiten the rewards.
74
+ kl_coef (`float`, *optional*, defaults to `0.05`):
75
+ KL coefficient.
76
+ cliprange (`float`, *optional*, defaults to `0.2`):
77
+ Clip range.
78
+ rloo_k (`int`, *optional*, defaults to `2`):
79
+ REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
80
+ normalize_reward (`bool`, *optional*, defaults to `False`):
81
+ Whether to normalize rewards.
82
+ reward_clip_range (`float`, *optional*, defaults to `10.0`):
83
+ Clip range for rewards.
84
+ normalize_advantage (`bool`, *optional*, defaults to `False`):
85
+ Whether to normalize advantages.
86
+ token_level_kl (`bool`, *optional*, defaults to `True`):
87
+ Whether to use token-level KL penalty or sequence-level KL penalty.
88
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
89
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
90
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
91
+ capacity of a single GPU, albeit at the cost of slower generation.
92
+
93
+ """
94
+ vllm_sampling_params: Optional[Any] = field(
95
+ default = None,
96
+ metadata = {'help': 'vLLM SamplingParams'},
97
+ )
98
+ unsloth_num_chunks : Optional[int] = field(
99
+ default = -1,
100
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
101
+ )
102
+
103
+ def __init__(
104
+ self,
105
+ output_dir = None,
106
+ overwrite_output_dir = None,
107
+ do_train = False,
108
+ do_eval = False,
109
+ do_predict = False,
110
+ eval_strategy = 'no',
111
+ prediction_loss_only = False,
112
+ per_device_train_batch_size = 4,
113
+ per_device_eval_batch_size = 4,
114
+ per_gpu_train_batch_size = None,
115
+ per_gpu_eval_batch_size = None,
116
+ gradient_accumulation_steps = 2,
117
+ eval_accumulation_steps = 2,
118
+ eval_delay = 0,
119
+ torch_empty_cache_steps = 250,
120
+ learning_rate = 5e-05,
121
+ weight_decay = 0.01,
122
+ adam_beta1 = 0.9,
123
+ adam_beta2 = 0.999,
124
+ adam_epsilon = 1e-08,
125
+ max_grad_norm = 1.0,
126
+ num_train_epochs = 3.0,
127
+ max_steps = -1,
128
+ lr_scheduler_type = 'linear',
129
+ warmup_ratio = 0.1,
130
+ warmup_steps = 0,
131
+ log_level = 'passive',
132
+ log_level_replica = 'warning',
133
+ log_on_each_node = True,
134
+ logging_dir = None,
135
+ logging_strategy = 'steps',
136
+ logging_first_step = False,
137
+ logging_steps = 1,
138
+ logging_nan_inf_filter = False,
139
+ save_strategy = 'steps',
140
+ save_steps = 500,
141
+ save_total_limit = None,
142
+ save_safetensors = True,
143
+ save_on_each_node = False,
144
+ save_only_model = False,
145
+ restore_callback_states_from_checkpoint = False,
146
+ no_cuda = False,
147
+ use_cpu = False,
148
+ use_mps_device = False,
149
+ seed = 3407,
150
+ data_seed = 3407,
151
+ jit_mode_eval = False,
152
+ use_ipex = False,
153
+ bf16 = False,
154
+ fp16 = False,
155
+ fp16_opt_level = 'O1',
156
+ half_precision_backend = 'auto',
157
+ bf16_full_eval = False,
158
+ fp16_full_eval = False,
159
+ tf32 = None,
160
+ local_rank = -1,
161
+ ddp_backend = None,
162
+ tpu_num_cores = None,
163
+ tpu_metrics_debug = False,
164
+ debug = '',
165
+ dataloader_drop_last = False,
166
+ eval_steps = None,
167
+ dataloader_num_workers = 0,
168
+ dataloader_prefetch_factor = None,
169
+ past_index = -1,
170
+ run_name = None,
171
+ disable_tqdm = None,
172
+ remove_unused_columns = True,
173
+ label_names = None,
174
+ load_best_model_at_end = False,
175
+ metric_for_best_model = None,
176
+ greater_is_better = None,
177
+ ignore_data_skip = False,
178
+ fsdp = '',
179
+ fsdp_min_num_params = 0,
180
+ fsdp_config = None,
181
+ fsdp_transformer_layer_cls_to_wrap = None,
182
+ accelerator_config = None,
183
+ deepspeed = None,
184
+ label_smoothing_factor = 0.0,
185
+ optim = 'adamw_8bit',
186
+ optim_args = None,
187
+ adafactor = False,
188
+ group_by_length = False,
189
+ length_column_name = 'length',
190
+ report_to = None,
191
+ ddp_find_unused_parameters = None,
192
+ ddp_bucket_cap_mb = None,
193
+ ddp_broadcast_buffers = None,
194
+ dataloader_pin_memory = True,
195
+ dataloader_persistent_workers = False,
196
+ skip_memory_metrics = True,
197
+ use_legacy_prediction_loop = False,
198
+ push_to_hub = False,
199
+ resume_from_checkpoint = None,
200
+ hub_model_id = None,
201
+ hub_strategy = 'every_save',
202
+ hub_token = None,
203
+ hub_private_repo = None,
204
+ hub_always_push = False,
205
+ hub_revision = None,
206
+ gradient_checkpointing = False,
207
+ gradient_checkpointing_kwargs = None,
208
+ include_inputs_for_metrics = False,
209
+ eval_do_concat_batches = True,
210
+ fp16_backend = 'auto',
211
+ push_to_hub_model_id = None,
212
+ push_to_hub_organization = None,
213
+ push_to_hub_token = None,
214
+ mp_parameters = '',
215
+ auto_find_batch_size = True,
216
+ full_determinism = False,
217
+ torchdynamo = None,
218
+ ray_scope = 'last',
219
+ ddp_timeout = 1800,
220
+ torch_compile = False,
221
+ torch_compile_backend = None,
222
+ torch_compile_mode = None,
223
+ include_tokens_per_second = False,
224
+ include_num_input_tokens_seen = False,
225
+ neftune_noise_alpha = None,
226
+ optim_target_modules = None,
227
+ batch_eval_metrics = False,
228
+ eval_on_start = False,
229
+ use_liger_kernel = False,
230
+ liger_kernel_config = None,
231
+ eval_use_gather_object = False,
232
+ average_tokens_across_devices = True,
233
+ dataset_num_proc = None,
234
+ num_mini_batches = 1,
235
+ total_episodes = None,
236
+ local_rollout_forward_batch_size = 64,
237
+ num_sample_generations = 10,
238
+ response_length = 53,
239
+ stop_token = None,
240
+ stop_token_id = None,
241
+ temperature = 0.7,
242
+ missing_eos_penalty = None,
243
+ sft_model_path = 'EleutherAI/pythia-160m',
244
+ world_size = None,
245
+ num_total_batches = None,
246
+ micro_batch_size = None,
247
+ local_batch_size = None,
248
+ batch_size = None,
249
+ local_mini_batch_size = None,
250
+ mini_batch_size = None,
251
+ exp_name = 'rloo_config',
252
+ reward_model_path = 'EleutherAI/pythia-160m',
253
+ num_ppo_epochs = 4,
254
+ whiten_rewards = False,
255
+ kl_coef = 0.05,
256
+ cliprange = 0.2,
257
+ rloo_k = 2,
258
+ normalize_reward = False,
259
+ reward_clip_range = 10.0,
260
+ normalize_advantage = False,
261
+ token_level_kl = False,
262
+ ds3_gather_for_generation = True,
263
+ vllm_sampling_params = None,
264
+ unsloth_num_chunks = -1,
265
+
266
+ **kwargs,
267
+ ):
268
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
269
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
270
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
271
+ output_dir = 'unsloth_training_checkpoints'
272
+ save_strategy = 'no'
273
+ if dataset_num_proc is None:
274
+ from multiprocessing import cpu_count
275
+ dataset_num_proc = min(cpu_count()*2, 2)
276
+ if temperature <= 0:
277
+ raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
278
+ elif temperature >= 10:
279
+ raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
280
+
281
+
282
+ super().__init__(
283
+ output_dir = output_dir,
284
+ overwrite_output_dir = overwrite_output_dir,
285
+ do_train = do_train,
286
+ do_eval = do_eval,
287
+ do_predict = do_predict,
288
+ eval_strategy = eval_strategy,
289
+ prediction_loss_only = prediction_loss_only,
290
+ per_device_train_batch_size = per_device_train_batch_size,
291
+ per_device_eval_batch_size = per_device_eval_batch_size,
292
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
293
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
294
+ gradient_accumulation_steps = gradient_accumulation_steps,
295
+ eval_accumulation_steps = eval_accumulation_steps,
296
+ eval_delay = eval_delay,
297
+ torch_empty_cache_steps = torch_empty_cache_steps,
298
+ learning_rate = learning_rate,
299
+ weight_decay = weight_decay,
300
+ adam_beta1 = adam_beta1,
301
+ adam_beta2 = adam_beta2,
302
+ adam_epsilon = adam_epsilon,
303
+ max_grad_norm = max_grad_norm,
304
+ num_train_epochs = num_train_epochs,
305
+ max_steps = max_steps,
306
+ lr_scheduler_type = lr_scheduler_type,
307
+ warmup_ratio = warmup_ratio,
308
+ warmup_steps = warmup_steps,
309
+ log_level = log_level,
310
+ log_level_replica = log_level_replica,
311
+ log_on_each_node = log_on_each_node,
312
+ logging_dir = logging_dir,
313
+ logging_strategy = logging_strategy,
314
+ logging_first_step = logging_first_step,
315
+ logging_steps = logging_steps,
316
+ logging_nan_inf_filter = logging_nan_inf_filter,
317
+ save_strategy = save_strategy,
318
+ save_steps = save_steps,
319
+ save_total_limit = save_total_limit,
320
+ save_safetensors = save_safetensors,
321
+ save_on_each_node = save_on_each_node,
322
+ save_only_model = save_only_model,
323
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
324
+ no_cuda = no_cuda,
325
+ use_cpu = use_cpu,
326
+ use_mps_device = use_mps_device,
327
+ seed = seed,
328
+ data_seed = data_seed,
329
+ jit_mode_eval = jit_mode_eval,
330
+ use_ipex = use_ipex,
331
+ bf16 = bf16,
332
+ fp16 = fp16,
333
+ fp16_opt_level = fp16_opt_level,
334
+ half_precision_backend = half_precision_backend,
335
+ bf16_full_eval = bf16_full_eval,
336
+ fp16_full_eval = fp16_full_eval,
337
+ tf32 = tf32,
338
+ local_rank = local_rank,
339
+ ddp_backend = ddp_backend,
340
+ tpu_num_cores = tpu_num_cores,
341
+ tpu_metrics_debug = tpu_metrics_debug,
342
+ debug = debug,
343
+ dataloader_drop_last = dataloader_drop_last,
344
+ eval_steps = eval_steps,
345
+ dataloader_num_workers = dataloader_num_workers,
346
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
347
+ past_index = past_index,
348
+ run_name = run_name,
349
+ disable_tqdm = disable_tqdm,
350
+ remove_unused_columns = remove_unused_columns,
351
+ label_names = label_names,
352
+ load_best_model_at_end = load_best_model_at_end,
353
+ metric_for_best_model = metric_for_best_model,
354
+ greater_is_better = greater_is_better,
355
+ ignore_data_skip = ignore_data_skip,
356
+ fsdp = fsdp,
357
+ fsdp_min_num_params = fsdp_min_num_params,
358
+ fsdp_config = fsdp_config,
359
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
360
+ accelerator_config = accelerator_config,
361
+ deepspeed = deepspeed,
362
+ label_smoothing_factor = label_smoothing_factor,
363
+ optim = optim,
364
+ optim_args = optim_args,
365
+ adafactor = adafactor,
366
+ group_by_length = group_by_length,
367
+ length_column_name = length_column_name,
368
+ report_to = report_to,
369
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
370
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
371
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
372
+ dataloader_pin_memory = dataloader_pin_memory,
373
+ dataloader_persistent_workers = dataloader_persistent_workers,
374
+ skip_memory_metrics = skip_memory_metrics,
375
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
376
+ push_to_hub = push_to_hub,
377
+ resume_from_checkpoint = resume_from_checkpoint,
378
+ hub_model_id = hub_model_id,
379
+ hub_strategy = hub_strategy,
380
+ hub_token = hub_token,
381
+ hub_private_repo = hub_private_repo,
382
+ hub_always_push = hub_always_push,
383
+ hub_revision = hub_revision,
384
+ gradient_checkpointing = gradient_checkpointing,
385
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
386
+ include_inputs_for_metrics = include_inputs_for_metrics,
387
+ eval_do_concat_batches = eval_do_concat_batches,
388
+ fp16_backend = fp16_backend,
389
+ push_to_hub_model_id = push_to_hub_model_id,
390
+ push_to_hub_organization = push_to_hub_organization,
391
+ push_to_hub_token = push_to_hub_token,
392
+ mp_parameters = mp_parameters,
393
+ auto_find_batch_size = auto_find_batch_size,
394
+ full_determinism = full_determinism,
395
+ torchdynamo = torchdynamo,
396
+ ray_scope = ray_scope,
397
+ ddp_timeout = ddp_timeout,
398
+ torch_compile = torch_compile,
399
+ torch_compile_backend = torch_compile_backend,
400
+ torch_compile_mode = torch_compile_mode,
401
+ include_tokens_per_second = include_tokens_per_second,
402
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
403
+ neftune_noise_alpha = neftune_noise_alpha,
404
+ optim_target_modules = optim_target_modules,
405
+ batch_eval_metrics = batch_eval_metrics,
406
+ eval_on_start = eval_on_start,
407
+ use_liger_kernel = use_liger_kernel,
408
+ liger_kernel_config = liger_kernel_config,
409
+ eval_use_gather_object = eval_use_gather_object,
410
+ average_tokens_across_devices = average_tokens_across_devices,
411
+ dataset_num_proc = dataset_num_proc,
412
+ num_mini_batches = num_mini_batches,
413
+ total_episodes = total_episodes,
414
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
415
+ num_sample_generations = num_sample_generations,
416
+ response_length = response_length,
417
+ stop_token = stop_token,
418
+ stop_token_id = stop_token_id,
419
+ temperature = temperature,
420
+ missing_eos_penalty = missing_eos_penalty,
421
+ sft_model_path = sft_model_path,
422
+ world_size = world_size,
423
+ num_total_batches = num_total_batches,
424
+ micro_batch_size = micro_batch_size,
425
+ local_batch_size = local_batch_size,
426
+ batch_size = batch_size,
427
+ local_mini_batch_size = local_mini_batch_size,
428
+ mini_batch_size = mini_batch_size,
429
+ exp_name = exp_name,
430
+ reward_model_path = reward_model_path,
431
+ num_ppo_epochs = num_ppo_epochs,
432
+ whiten_rewards = whiten_rewards,
433
+ kl_coef = kl_coef,
434
+ cliprange = cliprange,
435
+ rloo_k = rloo_k,
436
+ normalize_reward = normalize_reward,
437
+ reward_clip_range = reward_clip_range,
438
+ normalize_advantage = normalize_advantage,
439
+ token_level_kl = token_level_kl,
440
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
441
+ self.vllm_sampling_params = vllm_sampling_params
442
+ self.unsloth_num_chunks = unsloth_num_chunks
443
+
444
+ pass
445
+
446
+ class _UnslothRLOOTrainer(Trainer):
447
+ _tag_names = ["trl", "rloo"]
448
+
449
def __init__(
    self,
    config: RLOOConfig,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ],
    policy: nn.Module,
    ref_policy: nn.Module,
    reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
    train_dataset: Dataset,
    data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    # less commonly used
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    callbacks: Optional[list[TrainerCallback]] = None,
) -> None:
    """Set up the RLOO trainer: batch-size bookkeeping, optimizer, callbacks and dataloaders.

    Besides storing its arguments on ``self``, this constructor mutates ``config`` in place:
    it fills in ``total_episodes`` (when ``None``), ``world_size``, ``local_batch_size``,
    ``micro_batch_size``, ``batch_size``, ``mini_batch_size``, ``local_mini_batch_size``,
    ``num_total_batches``, ``run_name`` and, when ``stop_token == "eos"``, ``stop_token_id``.
    It also sets ``eos_token_id`` and ``pad_token_id`` on ``policy.generation_config`` to ``None``.

    Args:
        config: Training configuration; stored as ``self.args``.
        processing_class: Tokenizer/processor used for the default collator, decoding, and the
            pad/eos token ids.
        policy: Model being trained; becomes ``self.model``.
        ref_policy: Reference model; must be a different object than ``policy``.
        reward_model: Either an ``nn.Module`` or a callable mapping decoded texts to scores.
        train_dataset: Dataset of prompts; ``len()`` is taken on it.
        data_collator: Collator for both dataloaders; defaults to
            ``DataCollatorWithPadding(processing_class)``.
        eval_dataset: Dataset wrapped by the evaluation dataloader.
        optimizers: ``(optimizer, lr_scheduler)`` pair assigned before
            ``create_optimizer_and_scheduler`` is called.
        callbacks: Extra callbacks appended to the default ones.

    Raises:
        ValueError: If ``ref_policy`` is the same object as ``policy``.
    """
    if ref_policy is policy:
        raise ValueError(
            "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
            "same as `policy`, you must pass a copy of it, or `None` if you use peft."
        )

    self.args = config
    args = config
    self.processing_class = processing_class
    self.policy = policy

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DataCollatorWithPadding(self.processing_class)

    self.policy.generation_config.eos_token_id = (
        None  # disable `pad_token_id` and `eos_token_id` because we just want to
    )
    self.policy.generation_config.pad_token_id = None  # generate tokens without truncation / padding

    self.ref_policy = ref_policy
    self.reward_model = reward_model
    self.train_dataset = train_dataset
    self.train_dataset_len = len(train_dataset)
    self.data_collator = data_collator
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

    #########
    # calculate various batch sizes
    #########
    if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
        args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.accelerator = accelerator
    args.world_size = accelerator.num_processes
    args.local_batch_size = (
        args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
    )
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    args.batch_size = int(args.local_batch_size * args.world_size)
    args.mini_batch_size = exact_div(
        args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
    )
    args.local_mini_batch_size = exact_div(
        args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
    )
    args.num_total_batches = math.ceil(
        args.total_episodes / args.batch_size
    )  # we may train for more than `total_episodes`
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
    if args.num_sample_generations > 0:
        self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
    self.local_dataloader_batch_size = exact_div(
        args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
    )  # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times

    #########
    # setup model, optimizer, and others
    #########
    for module in [policy, ref_policy, reward_model]:
        # A callable reward function is not an nn.Module and has no dropout to disable.
        if isinstance(module, nn.Module):
            disable_dropout_in_model(module)
    if args.stop_token and args.stop_token == "eos":
        args.stop_token_id = self.processing_class.eos_token_id
    self.model = policy
    self.create_optimizer_and_scheduler(
        num_training_steps=args.num_total_batches
    )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = OnlineTrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )

    self.current_flos = 0
    self.hp_search_backend = None
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)
    self.backup_model = None

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=self.local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,  # needed; otherwise the last batch will be of ragged shape
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
    torch.manual_seed(self.local_seed)  # reset the local seed again

    self.eval_dataloader = DataLoader(
        self.eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=True,
    )  # no need to shuffle eval dataset
    self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

    if self.is_deepspeed_enabled:
        if isinstance(self.reward_model, nn.Module):
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
        self.ref_policy = prepare_deepspeed(
            self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
        )
        self.deepspeed = self.model
    else:
        # Without DeepSpeed the frozen models are simply moved to the accelerator device.
        self.ref_policy = self.ref_policy.to(self.accelerator.device)
        if isinstance(self.reward_model, nn.Module):
            self.reward_model = self.reward_model.to(self.accelerator.device)
610
+
611
def get_train_dataloader(self) -> DataLoader:
    """Return the training dataloader stored on ``self.dataloader``."""
    train_loader = self.dataloader
    return train_loader
613
+
614
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader stored on ``self.eval_dataloader``."""
    eval_loader = self.eval_dataloader
    return eval_loader
616
+
617
def train(self):
    """Run the RLOO training loop for ``args.num_total_batches`` updates.

    Each update:

    1. draws one batch of prompts and repeats it ``args.rloo_k`` times,
    2. generates responses with the current policy and computes policy / reference log-probs,
    3. scores the (optionally truncated) responses with ``self.reward_model``,
    4. builds the KL-penalised reward and the leave-one-out advantages,
    5. runs ``args.num_ppo_epochs`` epochs of clipped policy-gradient updates over shuffled
       mini-/micro-batches,
    6. logs metrics, steps the LR scheduler and fires the step/save callbacks.

    Returns nothing; progress is recorded on ``self.state`` and through ``self.log``.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    self.model_wrapped = self.model
    ref_policy = self.ref_policy
    reward_model = self.reward_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        # Cycle over the dataloader forever so `next()` never raises StopIteration.
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    # NOTE: never written below, so the "val/clipfrac_avg" metric is always 0.
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    for update in range(1, args.num_total_batches + 1):
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        with torch.no_grad():
            queries = data["input_ids"].to(device)
            # Each prompt appears rloo_k times so every sample has rloo_k - 1 peers for its baseline.
            queries = queries.repeat(args.rloo_k, 1)
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []

            # Generate responses and compute logprobs
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            # Process responses in batches
            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                logprob = selective_log_softmax(logits, response)
                del logits
                empty_cache()

                ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        args.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1

                if isinstance(reward_model, nn.Module):
                    _, score, _ = get_reward(
                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )
                else:
                    # Callable reward: score the decoded text instead of token ids.
                    score = torch.tensor(
                        reward_model(
                            processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
                        ),
                        dtype=torch.float,
                    ).to(device)

                # Store batch results
                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)

            # Concatenate all batched results
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            del (logprob, ref_logprob, score)
            empty_cache()
            gc.collect()

            # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
            # responses not passing that filter will receive a low (fixed) score
            # only query humans on responses that pass that filter
            contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
            if args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)

            # 4. compute rewards
            # Compute KL divergence
            kl = logprobs - ref_logprobs

            # Normalize rewards
            if args.normalize_reward:
                scores = (scores - scores.mean()) / (scores.std() + 1e-8)
                scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)

            # Compute total reward with KL penalty
            if args.token_level_kl:
                # Token-level KL penalty: apply KL penalty per token
                kl_reward = -args.kl_coef * kl

                # Get the index of the last non-padded token for each sequence
                eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
                last_reward = torch.zeros_like(kl)
                # Ensure scores has correct shape and type
                scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
                last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)

                # Combine KL reward and last reward
                non_score_reward = kl_reward.sum(1)  # Keep this for logging
                reward = last_reward + kl_reward
                rlhf_reward = reward.sum(1)  # Sum across sequence length
            else:
                # Sequence-level KL penalty: sum KL across tokens first
                sequence_kl = kl.sum(1)
                non_score_reward = -args.kl_coef * sequence_kl
                rlhf_reward = non_score_reward + scores

            # vectorized RLOO advantages implementation
            rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
            # Baseline for a sample = mean reward of the other rloo_k - 1 samples of the same prompt.
            baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
            advantages = rlhf_reward - baseline
            advantages = advantages.flatten()

            # Normalize advantages
            if args.normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            empty_cache()

        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]

                        # Get batch data
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]

                        # Forward pass
                        output = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7

                        # Compute new logprobs
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )

                        # Compute probability ratios
                        new_ratio = (new_logprobs - mb_logprobs).exp()  # per-token ratio, logged only
                        new_logprobs = new_logprobs.sum(1)
                        mb_logprobs = mb_logprobs.sum(1)
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)  # sequence-level ratio used in the loss

                        # PPO clipped loss
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = pg_loss_max.mean()

                        # Final loss
                        loss = pg_loss

                        # Optimization step
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()

                        with torch.no_grad():
                            pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype = torch.float32).to(logits.dtype)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1

                # del everything and empty cache
                # fmt: off
                del (
                    output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
                    pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
                    mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                empty_cache()

        # Compute metrics
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.mean()
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len)  # used by self.log
            self.log(metrics)
        del kl, mean_kl, mean_entropy, scores

        # The scheduler is stepped once per rollout batch, not once per optimizer step.
        self.lr_scheduler.step()
        self.state.global_step += 1
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        empty_cache()
        gc.collect()

        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        # This class's `_save_checkpoint(model, trial)` override takes no `metrics` argument;
        # passing one would raise TypeError, so call it exactly like the in-loop save above.
        self._save_checkpoint(model, trial=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
940
+
941
def generate_completions(self, sampling: bool = False):
    """Generate completions for the evaluation prompts, score them, and report them as a table.

    For every batch of ``self.eval_dataloader`` the unwrapped policy generates a response, the
    response is optionally truncated at ``args.stop_token_id``, and ``self.reward_model`` scores
    the prompt+response. Queries, responses and scores are gathered across processes into a
    DataFrame. On the main process the first five rows are printed (when ``rich`` is available)
    and the table is sent to wandb / comet_ml if they appear in ``args.report_to``.

    Args:
        sampling: When true, stop after the first evaluation batch.
    """
    args = self.args
    tokenizer = self.processing_class
    sampling_settings = {
        "max_new_tokens": self.args.response_length,
        "temperature": (0.01 + 1e-7),
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
    }
    generation_config = GenerationConfig(**sampling_settings)

    rows = defaultdict(list)
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        for batch in self.eval_dataloader:
            prompts = batch["input_ids"]
            with torch.no_grad():
                prompt_len = prompts.shape[1]
                full_sequences, _ = batch_generation(
                    unwrapped_model,
                    prompts,
                    prompts.shape[0],
                    tokenizer.pad_token_id,
                    generation_config,
                )
                completion = full_sequences[:, prompt_len:]
                if args.stop_token_id is None:
                    trimmed = completion
                else:  # a stop_token_id of 0 is valid, hence the explicit None check
                    trimmed = truncate_response(args.stop_token_id, tokenizer.pad_token_id, completion)

                # Keep the two gathers in this order so every process issues them identically.
                prompt_texts = tokenizer.batch_decode(prompts, skip_special_tokens=True)
                rows["query"].extend(gather_object(prompt_texts))
                completion_texts = tokenizer.batch_decode(trimmed)
                rows["model response"].extend(gather_object(completion_texts))

                scored_sequences = torch.cat((prompts, trimmed), 1)

                scorer = self.reward_model
                if isinstance(scorer, nn.Module):
                    _, score, _ = get_reward(
                        scorer,
                        scored_sequences,
                        tokenizer.pad_token_id,
                        prompt_len,
                    )
                else:
                    decoded = tokenizer.batch_decode(scored_sequences, skip_special_tokens=True)
                    score = torch.tensor(scorer(decoded), dtype=torch.float).to(scored_sequences.device)
                gathered_scores = self.accelerator.gather_for_metrics(score)
                rows["score"].extend(gathered_scores.float().cpu().numpy())

            if sampling:
                break
    completions_df = pd.DataFrame(rows)

    # Only the main process prints / uploads the table.
    if not self.accelerator.is_main_process:
        return

    if is_rich_available():
        print_rich_table(completions_df.iloc[0:5])

    if "wandb" in args.report_to:
        import wandb

        if wandb.run is not None:
            wandb.log({"completions": wandb.Table(dataframe=completions_df)})

    if "comet_ml" in args.report_to:
        log_table_to_comet_experiment(name="completions.csv", table=completions_df)
1016
+
1017
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Write the model card, then delegate the checkpoint itself to the parent trainer.

    The model name used for the card is the last path component of `args.hub_model_id` when it is set,
    otherwise the name of `args.output_dir`.

    Args:
        model: Model to checkpoint; forwarded to the parent implementation.
        trial: Hyper-parameter search trial (or `None`); forwarded to the parent implementation.
        metrics: Accepted for call-site compatibility only. The end-of-training save in this trainer
            calls `_save_checkpoint(model, trial=None, metrics=None)`; the value is not forwarded.
    """
    if self.args.hub_model_id is None:
        model_name = Path(self.args.output_dir).name
    else:
        model_name = self.args.hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
1025
+
1026
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Build a draft model card from the trainer's state and save it as `README.md` in `args.output_dir`.

    Does nothing on processes for which `is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that is a local directory is not reported as the base model.
    base_model = None
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path

    # Normalize `tags` into a fresh set so the caller's object is never mutated.
    if tags is None:
        tag_set = set()
    elif isinstance(tags, str):
        tag_set = {tags}
    else:
        tag_set = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tag_set.add("unsloth")

    tag_set.update(self._tag_names)

    citation = textwrap.dedent("""\
    @inproceedings{ahmadian2024back,
        title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
        author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
        year = 2024,
        booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
        publisher = {Association for Computational Linguistics},
        pages = {12248--12267},
        editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
    }""")

    # Link the active Weights & Biases run only when the package is available and a run exists.
    wandb_url = None
    if is_wandb_available() and wandb.run is not None:
        wandb_url = wandb.run.url

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tag_set,
        wandb_url=wandb_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="RLOO",
        trainer_citation=citation,
        paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
        paper_id="2402.14740",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    model_card.save(readme_path)
1090
class UnslothRLOOTrainer(_UnslothRLOOTrainer):
    """
    Unsloth wrapper around the RLOO trainer.

    Before delegating to the parent constructor it fills in a default config, copies the policy's
    `max_seq_length` into the config when the config leaves it unset, switches the policy to training mode
    (when it exposes `for_training`), forces right padding on the processing class, reconciles the data
    collator with the dataset columns, and registers Unsloth's RL statistics patch. After the parent
    constructor it removes any NEFTune hook and attaches the accelerator's grad scaler to the policy and
    every nested `.model`.

    Args:
        config: Training configuration; `UnslothRLOOConfig()` is used when `None`.
        processing_class: Tokenizer / processor used for padding and by the default collators.
        policy: The policy model being trained.
        ref_policy: Reference policy, forwarded unchanged to the parent trainer.
        reward_model: Reward model or callable, forwarded unchanged to the parent trainer.
        train_dataset: Training dataset; its `column_names` are inspected to pick a collator.
        data_collator: Optional collator; may be replaced by a matching one (see above).
        eval_dataset: Optional evaluation dataset.
        callbacks: Optional trainer callbacks.
        **kwargs: Forwarded unchanged to the parent trainer.
    """
    def __init__(
        self,
        config,
        processing_class,
        policy,
        ref_policy,
        reward_model,
        train_dataset,
        data_collator = None,
        eval_dataset = None,
        callbacks = None,
        **kwargs
    ):
        # The set-up code below is written in terms of `args` / `model`, while this trainer's
        # signature calls them `config` / `policy`. Bind the aliases explicitly; without this the
        # first use of `args` raises UnboundLocalError and `model` is an undefined name.
        args = config
        model = policy
        if args is None: args = UnslothRLOOConfig()

        # Inherit the model's maximum sequence length when the config has the field but leaves it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if args.max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # Force right padding, both on the processing class and on a tokenizer nested inside it.
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        text_processor = processing_class

        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator when it disagrees with the presence / absence of a `labels` column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(text_processor, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(text_processor)
        else:
            # Vision collator: keep all columns and skip dataset preparation.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # A processor without `pad` cannot back a text collator; use its inner tokenizer instead.
            if not hasattr(text_processor, 'pad') and hasattr(text_processor, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(text_processor.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(text_processor.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('rloo_trainer', other_metrics)

        super().__init__(
            config = args,  # `args` rather than `config` so a defaulted config reaches the parent
            processing_class = processing_class,
            policy = policy,
            ref_policy = ref_policy,
            reward_model = reward_model,
            train_dataset = train_dataset,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            callbacks = callbacks,**kwargs)
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        if hasattr(self, 'accelerator'):
            # Attach the accelerator's grad scaler to the policy and to every nested `.model`.
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
unsloth_compiled_cache/UnslothRewardTrainer.py ADDED
@@ -0,0 +1,866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_rich_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Inductor options passed as `options=` to the `torch.compile` decorator used in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-softmax value of `logits` at the positions given by `index`, computed in chunks.

    `logits` is flattened to rows of size `logits.shape[-1]` and `index` to a flat vector; both are split
    into 4 chunks along the row axis. Each chunk is upcast to float32 and the selected log-probability is
    computed as `gathered_logit - logsumexp(row)`. The result is reshaped to
    `(logits.shape[0], logits.shape[1])`.

    NOTE(review): the final reshape presumably expects 3-D `logits` with one `index` entry per row of
    the flattened logits -- confirm against callers.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)

    # Split into 4 chunks only
    logit_parts = torch.chunk(flat_logits, chunks = 4, dim = 0)
    index_parts = torch.chunk(flat_index, chunks = 4, dim = 0)

    pieces = []
    # Each iteration is equivalent to selective_log_softmax(logit_part, index_part)
    for logit_part, index_part in zip(logit_parts, index_parts):
        logit_part = logit_part.to(torch.float32)
        picked = torch.gather(logit_part, dim = -1, index = index_part.unsqueeze(-1)).squeeze(-1)
        pieces.append(picked - torch.logsumexp(logit_part, dim = -1))

    return torch.concat(pieces).reshape((logits.shape[0], logits.shape[1]))
51
@dataclass
class UnslothRewardConfig(RewardConfig):
    """

    Configuration class for the [`RewardTrainer`].

    This class includes only the parameters that are specific to Reward training. For a full list of training
    arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
    class may differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
            limit. This argument is required if you want to use the default data collator.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        dataset_num_proc (`int`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, a value is chosen in
            `__init__`.
        center_rewards_coefficient (`float`, *optional*, defaults to `None`):
            Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
            https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the
            dataset is pretokenized.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to the parent config.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. Stored on the instance; not forwarded to the parent config.
        max_seq_length (`int`, *optional*, defaults to `None`):
            Maximum sequence length to truncate to. Stored on the instance; not forwarded to the parent
            config.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = False,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = True,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = True,
        max_length = 1024,
        disable_dropout = True,
        dataset_num_proc = None,
        center_rewards_coefficient = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        max_seq_length = None,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before any further set-up.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # With no output directory and the save settings left at this signature's defaults,
        # pick a default directory and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            # NOTE(review): min(cpu_count()*2, 2) evaluates to 2 whenever cpu_count() >= 1 -- confirm
            # whether a cap of 2 is intended.
            dataset_num_proc = min(cpu_count()*2, 2)

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            disable_dropout = disable_dropout,
            dataset_num_proc = dataset_num_proc,
            center_rewards_coefficient = center_rewards_coefficient,**kwargs)
        # Unsloth-only settings are kept on the instance; they are not passed to the parent config.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
378
+
379
+ class _UnslothRewardTrainer(Trainer):
380
+ _tag_names = ["trl", "reward-trainer"]
381
+
382
+ def __init__(
383
+ self,
384
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
385
+ args: Optional[RewardConfig] = None,
386
+ data_collator: Optional[DataCollator] = None,
387
+ train_dataset: Optional[Dataset] = None,
388
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
389
+ processing_class: Optional[
390
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
391
+ ] = None,
392
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
393
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
394
+ callbacks: Optional[list[TrainerCallback]] = None,
395
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
396
+ None,
397
+ None,
398
+ ),
399
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
400
+ peft_config: Optional[dict] = None,
401
+ ):
402
+ """
403
+ Initialize RewardTrainer.
404
+
405
+ Args:
406
+ model (`transformers.PreTrainedModel`):
407
+ The model to train, preferably an `AutoModelForSequenceClassification`.
408
+ args (`RewardConfig`):
409
+ The arguments to use for training.
410
+ data_collator (`transformers.DataCollator`):
411
+ The data collator to use for training. If None is specified, the default data collator
412
+ (`RewardDataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of
413
+ the sequences in the batch, given a dataset of paired sequences.
414
+ train_dataset (`datasets.Dataset`):
415
+ The dataset to use for training.
416
+ eval_dataset (`datasets.Dataset`):
417
+ The dataset to use for evaluation.
418
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
419
+ Processing class used to process the data. If provided, will be used to automatically process the
420
+ inputs for the model, and it will be saved along the model to make it easier to rerun an interrupted
421
+ training or reuse the fine-tuned model.
422
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
423
+ The model initializer to use for training. If None is specified, the default model initializer will be
424
+ used.
425
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
426
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
427
+ will be used.
428
+ callbacks (`list[transformers.TrainerCallback]`):
429
+ The callbacks to use for training.
430
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
431
+ The optimizer and scheduler to use for training.
432
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
433
+ The function to use to preprocess the logits before computing the metrics.
434
+ peft_config (`dict`, defaults to `None`):
435
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped
436
+ in a PEFT model.
437
+ """
438
+ if not is_peft_available() and peft_config is not None:
439
+ raise ValueError(
440
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
441
+ )
442
+ elif is_peft_available() and peft_config is not None:
443
+ if not isinstance(model, PeftModel):
444
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
445
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
446
+ inspect.signature(prepare_model_for_kbit_training).parameters
447
+ )
448
+
449
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
450
+
451
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
452
+ warnings.warn(
453
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
454
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
455
+ UserWarning,
456
+ )
457
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
458
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
459
+
460
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
461
+
462
+ model = model
463
+
464
+ # Disable dropout in the model
465
+ if args.disable_dropout:
466
+ disable_dropout_in_model(model)
467
+
468
+ if compute_metrics is None:
469
+ compute_metrics = compute_accuracy
470
+
471
+ if data_collator is None:
472
+ if processing_class is None:
473
+ raise ValueError(
474
+ "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
475
+ )
476
+
477
+ max_length = args.max_length
478
+
479
+ data_collator = RewardDataCollatorWithPadding(processing_class)
480
+
481
+ if args.remove_unused_columns:
482
+ try: # for bc before https://github.com/huggingface/transformers/pull/25435
483
+ args.remove_unused_columns = False
484
+ except FrozenInstanceError:
485
+ args = replace(args, remove_unused_columns=False)
486
+ # warn users
487
+ warnings.warn(
488
+ "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
489
+ " we have set it for you, but you should do it yourself in the future.",
490
+ UserWarning,
491
+ )
492
+
493
+ self.use_reward_data_collator = True
494
+ else:
495
+ self.use_reward_data_collator = False
496
+
497
+ # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
498
+ # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
499
+ # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
500
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
501
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
502
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
503
+ # issued.
504
+ model.warnings_issued["estimate_tokens"] = True
505
+
506
+ if "input_ids_chosen" not in train_dataset.column_names:
507
+ with PartialState().main_process_first():
508
+ fn_kwargs = {"tokenizer": processing_class}
509
+ train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
510
+ train_dataset = train_dataset.map(
511
+ _tokenize,
512
+ batched=True,
513
+ fn_kwargs=fn_kwargs,
514
+ num_proc=args.dataset_num_proc,
515
+ )
516
+ # This filter is important because otherwise you get samples that exceed the model's context length and
517
+ # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
518
+ # user might get surprised if N samples are missing from training.
519
+ train_dataset = train_dataset.filter(
520
+ lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
521
+ num_proc=args.dataset_num_proc,
522
+ )
523
+ if eval_dataset is not None:
524
+ eval_dataset = eval_dataset.map(
525
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
526
+ )
527
+ eval_dataset = eval_dataset.map(
528
+ _tokenize,
529
+ fn_kwargs=fn_kwargs,
530
+ batched=True,
531
+ num_proc=args.dataset_num_proc,
532
+ )
533
+ # This filter is important because otherwise you get samples that exceed the model's context length and
534
+ # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
535
+ # user might get surprised if N samples are missing from training.
536
+ eval_dataset = eval_dataset.filter(
537
+ lambda x: len(x["input_ids_chosen"]) <= max_length
538
+ and len(x["input_ids_rejected"]) <= max_length,
539
+ num_proc=args.dataset_num_proc,
540
+ )
541
+
542
+ super().__init__(
543
+ model=model,
544
+ args=args,
545
+ data_collator=data_collator,
546
+ train_dataset=train_dataset,
547
+ eval_dataset=eval_dataset,
548
+ processing_class=processing_class,
549
+ model_init=model_init,
550
+ compute_metrics=compute_metrics,
551
+ callbacks=callbacks,
552
+ optimizers=optimizers,
553
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
554
+ )
555
+
556
+ # Add tags for models that have been loaded with the correct transformers version
557
+ if hasattr(self.model, "add_model_tags"):
558
+ self.model.add_model_tags(self._tag_names)
559
+
560
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """
    Compute the pairwise preference loss for a batch of chosen/rejected sequences.

    The model is called once with the chosen inputs and once with the rejected inputs, and the `"logits"` entry
    of each output is used as the reward. The loss is the mean of `-logsigmoid(chosen - rejected)`; when
    `inputs` has a `"margin"` entry it is subtracted from the difference first. If
    `self.args.center_rewards_coefficient` is not `None`, that coefficient times the mean of
    `(chosen + rejected) ** 2` is added to the loss.

    Args:
        model: Callable taking `input_ids`, `attention_mask` and `return_dict`, whose result is indexed with
            `"logits"`.
        inputs: Batch holding `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected`,
            `attention_mask_rejected` and, optionally, `margin`.
        return_outputs: When truthy, also return the two reward tensors.
        num_items_in_batch: Accepted but not used by this method.

    Returns:
        The loss, or `(loss, {"rewards_chosen": ..., "rewards_rejected": ...})` when `return_outputs` is truthy.
    """
    chosen_output = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
        return_dict=True,
    )
    rewards_chosen = chosen_output["logits"]

    rejected_output = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )
    rewards_rejected = rejected_output["logits"]

    # Reward gap, optionally reduced by the per-sample margin.
    reward_gap = rewards_chosen - rewards_rejected
    if "margin" in inputs:
        reward_gap = reward_gap - inputs["margin"]
    loss = -nn.functional.logsigmoid(reward_gap).mean()

    # Optional penalty on the squared sum of the two rewards.
    center_coefficient = self.args.center_rewards_coefficient
    if center_coefficient is not None:
        loss = loss + center_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)

    if not return_outputs:
        return loss
    return loss, {
        "rewards_chosen": rewards_chosen,
        "rewards_rejected": rewards_rejected,
    }
592
+
593
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Run one evaluation step and return `(loss, logits, labels)`.

    Args:
        model: Model forwarded to `self.compute_loss`.
        inputs: Raw batch; it is passed through `self._prepare_inputs` before use.
        prediction_loss_only: When truthy, only the loss is returned and the other two slots are `None`.
        ignore_keys: Keys of the `compute_loss` output dict to drop. When `None`, falls back to
            `self.model.config.keys_to_ignore_at_inference` if the model has a `config`, otherwise to `[]`.

    Returns:
        `(loss, None, None)` when `prediction_loss_only` is truthy; otherwise the detached loss, the
        per-sample softmax over the stacked reward outputs, and an all-zero label vector with one entry per
        sample.
    """
    inputs = self._prepare_inputs(inputs)

    if ignore_keys is None:
        ignore_keys = (
            getattr(self.model.config, "keys_to_ignore_at_inference", [])
            if hasattr(self.model, "config")
            else []
        )

    with torch.no_grad():
        loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

    if prediction_loss_only:
        return (loss, None, None)

    loss = loss.detach()
    kept_outputs = nested_detach(
        tuple(value for key, value in logits_dict.items() if key not in ignore_keys)
    )
    # Stack the kept outputs, average over dim 2, then softmax across the stacked axis so the
    # preferences for each sample sum to 1; transpose to put samples first.
    preferences = torch.stack(kept_outputs).mean(dim=2).softmax(dim=0).T

    labels = self._prepare_inputs(torch.zeros(preferences.shape[0]))

    return loss, preferences, labels
624
+
625
def evaluate(self, *args, **kwargs):
    """
    Print a sample of reward predictions, then run the parent class's `evaluate`.

    The optional `num_print_samples` keyword (default `4`) is removed from `kwargs` and passed to
    `self.visualize_samples`; all remaining arguments are forwarded unchanged.
    """
    sample_count = kwargs.pop("num_print_samples", 4)
    self.visualize_samples(sample_count)
    return super().evaluate(*args, **kwargs)
629
+
630
def visualize_samples(self, num_print_samples: int):
    """
    Visualize the reward model logits prediction

    Iterates over the evaluation dataloader, collects the decoded chosen/rejected texts and the rounded
    prediction logits, and on process index 0 prints them (when `rich` is available) and logs the full table
    to Weights & Biases and/or Comet when those appear in `self.args.report_to`.

    Args:
        num_print_samples (`int`):
            The number of samples to print. Any negative value (e.g. `-1`) prints all samples. `evaluate`
            passes `4` unless the caller overrides it.
    """
    eval_dataloader = self.get_eval_dataloader()
    table = defaultdict(list)
    for inputs in eval_dataloader:
        _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
        chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
        rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
        table["chosen_text"].extend(gather_object(chosen_text))
        table["rejected_text"].extend(gather_object(rejected_text))
        table["logits"].extend(
            gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
        )
        # Stop early once enough samples were gathered; a negative count never stops early.
        if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
            break
    df = pd.DataFrame(table)
    if self.accelerator.process_index == 0:
        if is_rich_available():
            # A negative count means "print everything". Slicing with it (`df[:-1]`) would silently
            # drop trailing rows, so the slice is only applied for non-negative counts.
            printable = df if num_print_samples < 0 else df[:num_print_samples]
            print_rich_table(printable)
        if "wandb" in self.args.report_to:
            import wandb

            if wandb.run is not None:
                wandb.log({"completions": wandb.Table(dataframe=df)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="completions.csv",
                table=df,
            )
666
+
667
+ # Ensure the model card is saved along with the checkpoint
668
def _save_checkpoint(self, model, trial):
    """
    Write the model card, then delegate to the parent class's `_save_checkpoint`.

    The model name is the last path component of `self.args.output_dir` when `self.args.hub_model_id` is
    `None`, otherwise the part of `hub_model_id` after its final `/`.
    """
    hub_model_id = self.args.hub_model_id
    model_name = (
        Path(self.args.output_dir).name
        if hub_model_id is None
        else hub_model_id.split("/")[-1]
    )
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
675
+
676
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Does nothing unless `self.is_world_process_zero()` is true. Otherwise the card is written to
    `README.md` inside `self.args.output_dir`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. `self._tag_names` is always added, and `"unsloth"`
            is added when the model config has an `unsloth_version` attribute.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when the recorded name is not a local directory.
    config = self.model.config
    base_model = None
    if hasattr(config, "_name_or_path"):
        recorded_name = config._name_or_path
        if not os.path.isdir(recorded_name):
            base_model = recorded_name

    # Collect the tags into a fresh set so the caller's object is never mutated.
    if isinstance(tags, str):
        tag_set = {tags}
    elif tags is None:
        tag_set = set()
    else:
        tag_set = set(tags)

    if hasattr(config, "unsloth_version"):
        tag_set.add("unsloth")

    tag_set.update(self._tag_names)

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tag_set,
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="Reward",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    model_card.save(readme_path)
726
class UnslothRewardTrainer(_UnslothRewardTrainer):
    """
    Thin wrapper around `_UnslothRewardTrainer` whose constructor reconciles the training arguments with the
    model before delegating: mixed-precision flags, evaluation batch settings, padding side, data collator
    choice, and a few post-construction fix-ups (NEFTune hook, accelerator scaler).
    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        **kwargs
    ):
        def _strict_bool(value):
            # Anything that is not literally a bool is treated as "not requested".
            return value if type(value) is bool else False

        if args is None:
            args = UnslothRewardConfig()

        # ---- Mixed precision: align args.fp16 / args.bf16 with the model dtype ----
        bf16_requested = _strict_bool(getattr(args, 'bf16', False))
        fp16_requested = _strict_bool(getattr(args, 'fp16', False))
        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        model_is_float16 = dtype == torch.float16

        if not force_float32:
            if model_is_float16 and bf16_requested:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not model_is_float16 and fp16_requested:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif not (bf16_requested or fp16_requested) and mixed_precision_dtype == 'float32':
            # Nothing was requested explicitly: follow the model's own dtype.
            args.fp16 = model_is_float16
            args.bf16 = not model_is_float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if model_is_float16 else 'bf16'

        # ---- Evaluation defaults ----
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # ---- Full-precision evaluation flags must agree with the training precision ----
        fp16_full_eval = _strict_bool(getattr(args, 'fp16_full_eval', False))
        bf16_full_eval = _strict_bool(getattr(args, 'bf16_full_eval', False))
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif mixed_precision_dtype == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Request logits whenever a metrics hook is supplied.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the model's max_seq_length when args has the field but leaves it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if getattr(args, 'max_seq_length', None) is None and model_max_seq_length is not None:
                args.max_seq_length = model.max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Right-pad with the processing class and its inner tokenizer, when present ----
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'
        text_processor = processing_class

        # ---- Data collator selection ----
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if isinstance(data_collator, UnslothVisionDataCollator):
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}
        else:
            # Swap the collator so it matches whether the dataset carries a 'labels' column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(text_processor, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(text_processor)
            # If the processing class has no `pad` but wraps a tokenizer, rebuild the collator on that tokenizer.
            if not hasattr(text_processor, 'pad') and hasattr(text_processor, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(text_processor.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(text_processor.tokenizer, mlm = False, mlm_probability = 0.0)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('reward_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            **kwargs
        )

        # ---- Post-construction fix-ups ----
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        if hasattr(self, 'accelerator'):
            # Attach the accelerator's scaler to the model and to every nested `.model` beneath it.
            scaler = self.accelerator.scaler
            wrapped = model
            while hasattr(wrapped, 'model'):
                wrapped.accelerator_scaler = scaler
                wrapped = wrapped.model
            wrapped.accelerator_scaler = scaler
unsloth_compiled_cache/UnslothSFTTrainer.py ADDED
@@ -0,0 +1,1253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, dataclass, dataclasses, defaultdict, generate_model_card, get_act_offloading_ctx_manager, get_comet_experiment_url, get_peft_model, is_conversational, is_peft_available, is_wandb_available, nn, os, pack_dataset, pad, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, version, wandb, warnings, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_dataset, pad, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, peft, torch, os)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Backend options handed to `torch.compile` via its `options` argument by the
# compiled helper defined right after this mapping.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-probability of each selected index, computed in at most four row chunks.

    `logits` is flattened to `(-1, logits.shape[-1])` and `index` to one dimension. Each chunk is upcast to
    float32, and the result for a row is the gathered logit minus the row's `logsumexp`. The output is
    reshaped to `(logits.shape[0], logits.shape[1])`.

    NOTE(review): the final reshape presumes `logits` has exactly three dimensions and that `index` holds one
    entry per flattened row -- confirm against callers.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)
    logit_chunks = torch.chunk(flat_logits, chunks = 4, dim = 0)
    index_chunks = torch.chunk(flat_index, chunks = 4, dim = 0)

    gathered_logps = []
    # Per chunk this is the same computation as selective_log_softmax(chunk_logits, chunk_index)
    for logit_block, index_block in zip(logit_chunks, index_chunks):
        logit_block = logit_block.to(torch.float32)
        picked = torch.gather(logit_block, dim = -1, index = index_block.unsqueeze(-1)).squeeze(-1)
        log_normalizer = torch.logsumexp(logit_block, dim = -1)
        gathered_logps.append(picked - log_normalizer)

    merged = torch.concat(gathered_logps)
    return merged.reshape((logits.shape[0], logits.shape[1]))
51
+ @dataclass
52
+ class UnslothSFTConfig(SFTConfig):
53
+ """
54
+
55
+ Configuration class for the [`SFTTrainer`].
56
+
57
+ This class includes only the parameters that are specific to SFT training. For a full list of training arguments,
58
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
59
+ differ from those in [`~transformers.TrainingArguments`].
60
+
61
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
62
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
63
+ command line.
64
+
65
+ Parameters:
66
+ > Parameters that control the model
67
+
68
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
69
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
70
+ argument of the [`SFTTrainer`] is provided as a string.
71
+ chat_template_path (`str` or `None`, *optional*, defaults to `None`):
72
+ If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
73
+ or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
74
+ ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
75
+ embedding layer is resized accordingly.
76
+
77
+ > Parameters that control the data preprocessing
78
+
79
+ dataset_text_field (`str`, *optional*, defaults to `"text"`):
80
+ Name of the column that contains text data in the dataset.
81
+ dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
82
+ Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
83
+ `skip_prepare_dataset`.
84
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
85
+ Number of processes to use for processing the dataset.
86
+ eos_token (`str` or `None`, *optional*, defaults to `None`):
87
+ Token used to indicate the end of a turn or sequence. If `None`, it defaults to
88
+ `processing_class.eos_token`.
89
+ pad_token (`int` or `None`, *optional*, defaults to `None`):
90
+ Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
91
+ it falls back to `processing_class.eos_token`.
92
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
93
+ Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
94
+ If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
95
+ packing (`bool`, *optional*, defaults to `False`):
96
+ Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
97
+ padding. Uses `max_length` to define sequence length.
98
+ packing_strategy (`str`, *optional*, defaults to `"bfd"`):
99
+ Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`.
100
+ padding_free (`bool`, *optional*, defaults to `False`):
101
+ Whether to perform forward passes without padding by flattening all sequences in the batch into a single
102
+ continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
103
+ supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When
104
+ packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this
105
+ parameter.
106
+ pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
107
+ If set, the sequences will be padded to a multiple of this value.
108
+ eval_packing (`bool` or `None`, *optional*, defaults to `None`):
109
+ Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
110
+
111
+ > Parameters that control the training
112
+
113
+ completion_only_loss (`bool` or `None`, *optional*, defaults to `None`):
114
+ Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed
115
+ only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If
116
+ `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
117
+ loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full
118
+ sequence for [language modeling](#language-modeling) datasets.
119
+ assistant_only_loss (`bool`, *optional*, defaults to `False`):
120
+ Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed
121
+ only on the assistant responses, which is supported only for [conversational](#conversational) datasets. If `False`,
122
+ loss is computed on the entire sequence.
123
+ activation_offloading (`bool`, *optional*, defaults to `False`):
124
+ Whether to offload the activations to the CPU.
125
+
126
+ """
127
+ vllm_sampling_params: Optional[Any] = field(
128
+ default = None,
129
+ metadata = {'help': 'vLLM SamplingParams'},
130
+ )
131
+ unsloth_num_chunks : Optional[int] = field(
132
+ default = -1,
133
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
134
+ )
135
+ max_seq_length : Optional[int] = field(
136
+ default = None,
137
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
138
+ )
139
+ def __init__(
140
+ self,
141
+ output_dir = None,
142
+ overwrite_output_dir = None,
143
+ do_train = False,
144
+ do_eval = False,
145
+ do_predict = False,
146
+ eval_strategy = 'no',
147
+ prediction_loss_only = False,
148
+ per_device_train_batch_size = 4,
149
+ per_device_eval_batch_size = 4,
150
+ per_gpu_train_batch_size = None,
151
+ per_gpu_eval_batch_size = None,
152
+ gradient_accumulation_steps = 2,
153
+ eval_accumulation_steps = 2,
154
+ eval_delay = 0,
155
+ torch_empty_cache_steps = 250,
156
+ learning_rate = 5e-05,
157
+ weight_decay = 0.01,
158
+ adam_beta1 = 0.9,
159
+ adam_beta2 = 0.999,
160
+ adam_epsilon = 1e-08,
161
+ max_grad_norm = 1.0,
162
+ num_train_epochs = 3.0,
163
+ max_steps = -1,
164
+ lr_scheduler_type = 'linear',
165
+ warmup_ratio = 0.1,
166
+ warmup_steps = 0,
167
+ log_level = 'passive',
168
+ log_level_replica = 'warning',
169
+ log_on_each_node = True,
170
+ logging_dir = None,
171
+ logging_strategy = 'steps',
172
+ logging_first_step = False,
173
+ logging_steps = 1,
174
+ logging_nan_inf_filter = False,
175
+ save_strategy = 'steps',
176
+ save_steps = 500,
177
+ save_total_limit = None,
178
+ save_safetensors = True,
179
+ save_on_each_node = False,
180
+ save_only_model = False,
181
+ restore_callback_states_from_checkpoint = False,
182
+ no_cuda = False,
183
+ use_cpu = False,
184
+ use_mps_device = False,
185
+ seed = 3407,
186
+ data_seed = 3407,
187
+ jit_mode_eval = False,
188
+ use_ipex = False,
189
+ bf16 = False,
190
+ fp16 = False,
191
+ fp16_opt_level = 'O1',
192
+ half_precision_backend = 'auto',
193
+ bf16_full_eval = False,
194
+ fp16_full_eval = False,
195
+ tf32 = None,
196
+ local_rank = -1,
197
+ ddp_backend = None,
198
+ tpu_num_cores = None,
199
+ tpu_metrics_debug = False,
200
+ debug = '',
201
+ dataloader_drop_last = False,
202
+ eval_steps = None,
203
+ dataloader_num_workers = 0,
204
+ dataloader_prefetch_factor = None,
205
+ past_index = -1,
206
+ run_name = None,
207
+ disable_tqdm = None,
208
+ remove_unused_columns = True,
209
+ label_names = None,
210
+ load_best_model_at_end = False,
211
+ metric_for_best_model = None,
212
+ greater_is_better = None,
213
+ ignore_data_skip = False,
214
+ fsdp = '',
215
+ fsdp_min_num_params = 0,
216
+ fsdp_config = None,
217
+ fsdp_transformer_layer_cls_to_wrap = None,
218
+ accelerator_config = None,
219
+ deepspeed = None,
220
+ label_smoothing_factor = 0.0,
221
+ optim = 'adamw_8bit',
222
+ optim_args = None,
223
+ adafactor = False,
224
+ group_by_length = False,
225
+ length_column_name = 'length',
226
+ report_to = None,
227
+ ddp_find_unused_parameters = None,
228
+ ddp_bucket_cap_mb = None,
229
+ ddp_broadcast_buffers = None,
230
+ dataloader_pin_memory = True,
231
+ dataloader_persistent_workers = False,
232
+ skip_memory_metrics = True,
233
+ use_legacy_prediction_loop = False,
234
+ push_to_hub = False,
235
+ resume_from_checkpoint = None,
236
+ hub_model_id = None,
237
+ hub_strategy = 'every_save',
238
+ hub_token = None,
239
+ hub_private_repo = None,
240
+ hub_always_push = False,
241
+ hub_revision = None,
242
+ gradient_checkpointing = False,
243
+ gradient_checkpointing_kwargs = None,
244
+ include_inputs_for_metrics = False,
245
+ eval_do_concat_batches = True,
246
+ fp16_backend = 'auto',
247
+ push_to_hub_model_id = None,
248
+ push_to_hub_organization = None,
249
+ push_to_hub_token = None,
250
+ mp_parameters = '',
251
+ auto_find_batch_size = True,
252
+ full_determinism = False,
253
+ torchdynamo = None,
254
+ ray_scope = 'last',
255
+ ddp_timeout = 1800,
256
+ torch_compile = False,
257
+ torch_compile_backend = None,
258
+ torch_compile_mode = None,
259
+ include_tokens_per_second = False,
260
+ include_num_input_tokens_seen = False,
261
+ neftune_noise_alpha = None,
262
+ optim_target_modules = None,
263
+ batch_eval_metrics = False,
264
+ eval_on_start = False,
265
+ use_liger_kernel = False,
266
+ liger_kernel_config = None,
267
+ eval_use_gather_object = False,
268
+ average_tokens_across_devices = True,
269
+ model_init_kwargs = None,
270
+ chat_template_path = None,
271
+ dataset_text_field = 'text',
272
+ dataset_kwargs = None,
273
+ dataset_num_proc = None,
274
+ eos_token = None,
275
+ pad_token = None,
276
+ max_length = 1024,
277
+ packing = False,
278
+ packing_strategy = 'bfd',
279
+ padding_free = False,
280
+ pad_to_multiple_of = None,
281
+ eval_packing = None,
282
+ completion_only_loss = None,
283
+ assistant_only_loss = False,
284
+ activation_offloading = False,
285
+ vllm_sampling_params = None,
286
+ unsloth_num_chunks = -1,
287
+ max_seq_length = None,
288
+ **kwargs,
289
+ ):
290
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
291
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
292
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
293
+ output_dir = 'unsloth_training_checkpoints'
294
+ save_strategy = 'no'
295
+ if dataset_num_proc is None:
296
+ from multiprocessing import cpu_count
297
+ dataset_num_proc = min(cpu_count()*2, 2)
298
+
299
+ super().__init__(
300
+ output_dir = output_dir,
301
+ overwrite_output_dir = overwrite_output_dir,
302
+ do_train = do_train,
303
+ do_eval = do_eval,
304
+ do_predict = do_predict,
305
+ eval_strategy = eval_strategy,
306
+ prediction_loss_only = prediction_loss_only,
307
+ per_device_train_batch_size = per_device_train_batch_size,
308
+ per_device_eval_batch_size = per_device_eval_batch_size,
309
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
310
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
311
+ gradient_accumulation_steps = gradient_accumulation_steps,
312
+ eval_accumulation_steps = eval_accumulation_steps,
313
+ eval_delay = eval_delay,
314
+ torch_empty_cache_steps = torch_empty_cache_steps,
315
+ learning_rate = learning_rate,
316
+ weight_decay = weight_decay,
317
+ adam_beta1 = adam_beta1,
318
+ adam_beta2 = adam_beta2,
319
+ adam_epsilon = adam_epsilon,
320
+ max_grad_norm = max_grad_norm,
321
+ num_train_epochs = num_train_epochs,
322
+ max_steps = max_steps,
323
+ lr_scheduler_type = lr_scheduler_type,
324
+ warmup_ratio = warmup_ratio,
325
+ warmup_steps = warmup_steps,
326
+ log_level = log_level,
327
+ log_level_replica = log_level_replica,
328
+ log_on_each_node = log_on_each_node,
329
+ logging_dir = logging_dir,
330
+ logging_strategy = logging_strategy,
331
+ logging_first_step = logging_first_step,
332
+ logging_steps = logging_steps,
333
+ logging_nan_inf_filter = logging_nan_inf_filter,
334
+ save_strategy = save_strategy,
335
+ save_steps = save_steps,
336
+ save_total_limit = save_total_limit,
337
+ save_safetensors = save_safetensors,
338
+ save_on_each_node = save_on_each_node,
339
+ save_only_model = save_only_model,
340
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
341
+ no_cuda = no_cuda,
342
+ use_cpu = use_cpu,
343
+ use_mps_device = use_mps_device,
344
+ seed = seed,
345
+ data_seed = data_seed,
346
+ jit_mode_eval = jit_mode_eval,
347
+ use_ipex = use_ipex,
348
+ bf16 = bf16,
349
+ fp16 = fp16,
350
+ fp16_opt_level = fp16_opt_level,
351
+ half_precision_backend = half_precision_backend,
352
+ bf16_full_eval = bf16_full_eval,
353
+ fp16_full_eval = fp16_full_eval,
354
+ tf32 = tf32,
355
+ local_rank = local_rank,
356
+ ddp_backend = ddp_backend,
357
+ tpu_num_cores = tpu_num_cores,
358
+ tpu_metrics_debug = tpu_metrics_debug,
359
+ debug = debug,
360
+ dataloader_drop_last = dataloader_drop_last,
361
+ eval_steps = eval_steps,
362
+ dataloader_num_workers = dataloader_num_workers,
363
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
364
+ past_index = past_index,
365
+ run_name = run_name,
366
+ disable_tqdm = disable_tqdm,
367
+ remove_unused_columns = remove_unused_columns,
368
+ label_names = label_names,
369
+ load_best_model_at_end = load_best_model_at_end,
370
+ metric_for_best_model = metric_for_best_model,
371
+ greater_is_better = greater_is_better,
372
+ ignore_data_skip = ignore_data_skip,
373
+ fsdp = fsdp,
374
+ fsdp_min_num_params = fsdp_min_num_params,
375
+ fsdp_config = fsdp_config,
376
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
377
+ accelerator_config = accelerator_config,
378
+ deepspeed = deepspeed,
379
+ label_smoothing_factor = label_smoothing_factor,
380
+ optim = optim,
381
+ optim_args = optim_args,
382
+ adafactor = adafactor,
383
+ group_by_length = group_by_length,
384
+ length_column_name = length_column_name,
385
+ report_to = report_to,
386
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
387
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
388
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
389
+ dataloader_pin_memory = dataloader_pin_memory,
390
+ dataloader_persistent_workers = dataloader_persistent_workers,
391
+ skip_memory_metrics = skip_memory_metrics,
392
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
393
+ push_to_hub = push_to_hub,
394
+ resume_from_checkpoint = resume_from_checkpoint,
395
+ hub_model_id = hub_model_id,
396
+ hub_strategy = hub_strategy,
397
+ hub_token = hub_token,
398
+ hub_private_repo = hub_private_repo,
399
+ hub_always_push = hub_always_push,
400
+ hub_revision = hub_revision,
401
+ gradient_checkpointing = gradient_checkpointing,
402
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
403
+ include_inputs_for_metrics = include_inputs_for_metrics,
404
+ eval_do_concat_batches = eval_do_concat_batches,
405
+ fp16_backend = fp16_backend,
406
+ push_to_hub_model_id = push_to_hub_model_id,
407
+ push_to_hub_organization = push_to_hub_organization,
408
+ push_to_hub_token = push_to_hub_token,
409
+ mp_parameters = mp_parameters,
410
+ auto_find_batch_size = auto_find_batch_size,
411
+ full_determinism = full_determinism,
412
+ torchdynamo = torchdynamo,
413
+ ray_scope = ray_scope,
414
+ ddp_timeout = ddp_timeout,
415
+ torch_compile = torch_compile,
416
+ torch_compile_backend = torch_compile_backend,
417
+ torch_compile_mode = torch_compile_mode,
418
+ include_tokens_per_second = include_tokens_per_second,
419
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
420
+ neftune_noise_alpha = neftune_noise_alpha,
421
+ optim_target_modules = optim_target_modules,
422
+ batch_eval_metrics = batch_eval_metrics,
423
+ eval_on_start = eval_on_start,
424
+ use_liger_kernel = use_liger_kernel,
425
+ liger_kernel_config = liger_kernel_config,
426
+ eval_use_gather_object = eval_use_gather_object,
427
+ average_tokens_across_devices = average_tokens_across_devices,
428
+ model_init_kwargs = model_init_kwargs,
429
+ chat_template_path = chat_template_path,
430
+ dataset_text_field = dataset_text_field,
431
+ dataset_kwargs = dataset_kwargs,
432
+ dataset_num_proc = dataset_num_proc,
433
+ eos_token = eos_token,
434
+ pad_token = pad_token,
435
+ max_length = max_length,
436
+ packing = packing,
437
+ packing_strategy = packing_strategy,
438
+ padding_free = padding_free,
439
+ pad_to_multiple_of = pad_to_multiple_of,
440
+ eval_packing = eval_packing,
441
+ completion_only_loss = completion_only_loss,
442
+ assistant_only_loss = assistant_only_loss,
443
+ activation_offloading = activation_offloading,**kwargs)
444
+ self.vllm_sampling_params = vllm_sampling_params
445
+ self.unsloth_num_chunks = unsloth_num_chunks
446
+ self.max_seq_length = max_seq_length
447
+ pass
448
+
449
+ class _UnslothSFTTrainer(Trainer):
450
+ """"""
451
+
452
+ _tag_names = ["trl", "sft"]
453
+
454
def __init__(
    self,
    model: Union[str, nn.Module, PreTrainedModel],
    args: Optional[Union[SFTConfig, TrainingArguments]] = None,
    data_collator: Optional[DataCollator] = None, # type: ignore
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    compute_loss_func: Optional[Callable] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
    optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional["PeftConfig"] = None,
    formatting_func: Optional[Callable[[dict], str]] = None,
):
    """Set up the SFT trainer: config, tokenizer, model, collator and datasets.

    The steps run in this order, and later steps rely on earlier ones:

    1. Normalise `args` into an `SFTConfig` (a default one named
       ``"<model_name>-SFT"`` when `args` is `None`; a plain
       `TrainingArguments` is converted through its dict form).
    2. Load a tokenizer with `AutoTokenizer.from_pretrained` when no
       `processing_class` is given, and apply `args.eos_token` if set.
    3. Instantiate the model via `_create_model_from_path` when `model` is a
       string.
    4. Apply `args.chat_template_path`: a `.jinja` / `.j2` file is read as the
       template, anything else is passed to `clone_chat_template`.
    5. Decide padding-free mode and build the default
       `DataCollatorForLanguageModeling` when no collator is passed.
    6. Prepare the train/eval datasets through `_prepare_dataset`, unless
       `args.dataset_kwargs["skip_prepare_dataset"]` is truthy.
    7. Call the parent `Trainer.__init__`, then set up the activation
       offloading context and the model tags.

    Args:
        model: Model id / path (string) or an already instantiated model.
        args: Training configuration; see step 1 for the `None` case.
        data_collator: Custom collator. Rejected when padding-free is active.
        train_dataset: Training data. Its first sample is read
            unconditionally with `next(iter(train_dataset))`, so the `None`
            default cannot actually be used.
        eval_dataset: Evaluation dataset or a dict of named datasets.
        processing_class: Tokenizer / processor used for tokenisation.
        compute_loss_func, compute_metrics, callbacks, optimizers,
        optimizer_cls_and_kwargs, preprocess_logits_for_metrics: Forwarded
            unchanged to the parent `Trainer.__init__`.
        peft_config: Only referenced inside an `if False:` branch below, so it
            has no effect in this method.
        formatting_func: Optional function handed to `_prepare_dataset`.

    Raises:
        ValueError: If the configured EOS or pad token maps to no token id, if
            a custom collator is combined with padding-free mode, if
            `assistant_only_loss` is set for a non-conversational sample, or if
            `formatting_func` is combined with completion-only loss.
    """
    # Args
    model_id = model if isinstance(model, str) else model.config._name_or_path
    if args is None:
        model_name = model_id.split("/")[-1]
        args = SFTConfig(f"{model_name}-SFT")
    elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
        dict_args = args.to_dict()
        dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
        # NOTE(review): no default is given, so this raises KeyError if the key is absent.
        dict_args.pop("push_to_hub_token")
        args = SFTConfig(**dict_args)

    # Handle the tokenizer
    if processing_class is None:
        processing_class = AutoTokenizer.from_pretrained(model_id)

    if args.eos_token is not None:
        eos_token = args.eos_token
        eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
        if eos_token_id is None:
            raise ValueError(
                f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
                "in the vocabulary before using it as an EOS token."
            )
        processing_class.eos_token_id = eos_token_id

    # Model
    if args.model_init_kwargs is not None and not isinstance(model, str):
        warnings.warn(
            "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
            "The `model_init_kwargs` will be ignored."
        )
    if isinstance(model, str):
        model = self._create_model_from_path(model, args)

    if args.chat_template_path is not None:
        if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
            # A local Jinja file is used verbatim as the chat template; no tokens are added.
            with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
                processing_class.chat_template = chat_template_file.read()
            added_tokens = []
        else:
            # Any other value is handed to `clone_chat_template`, which returns the
            # (possibly replaced) model and processing class plus the added tokens.
            model, processing_class, added_tokens = clone_chat_template(
                model, processing_class, args.chat_template_path
            )
    else:
        added_tokens = []

    # PEFT configuration and model wrapping
    # The condition is the literal `False`, so this whole branch is dead code.
    if False:
        if added_tokens:
            # Ensure that the added tokens are trainable
            if peft_config.trainable_token_indices is None:
                peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
            elif "embed_tokens" not in peft_config.trainable_token_indices:
                peft_config.trainable_token_indices["embed_tokens"] = added_tokens
            else:
                peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)

            # Ensure that the lm_head is trainable
            if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
                warnings.warn(
                    "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
                    "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
                    "tokens, leading to degraded generation quality. To fix this, add "
                    "`modules_to_save=['lm_head']` to your PEFT configuration."
                )

                if peft_config.modules_to_save is None:
                    peft_config.modules_to_save = ["lm_head"]
                else:
                    peft_config.modules_to_save.append("lm_head")

    # Also a literal `False`: a no-op placeholder.
    if False:
        pass

    # Data collator
    # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing
    # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask.
    self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd")
    use_flash_attention = model.config._attn_implementation in [
        "flash_attention_2",
        "kernels-community/vllm-flash-attn3",
    ]
    if self.padding_free:
        if data_collator is not None:
            raise ValueError("Passing a custom data collator is not supported when using padding-free.")
        if args.packing and args.packing_strategy == "wrapped":
            warnings.warn(
                "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not "
                "recommended. Please refer to the documentation to understand why this is not recommended."
            )
        if not use_flash_attention:
            warnings.warn(
                "Padding-free training is enabled, but the attention implementation is not set to "
                "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
                "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
                "other implementations may lead to unexpected behavior. To ensure compatibility, set "
                "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
                "attention mechanism can handle flattened sequences."
            )
        if args.per_device_train_batch_size == 1 and not args.packing:
            warnings.warn(
                "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
                "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size "
                "to at least 2."
            )

    # Peek at the first training sample; it drives the completion-only default and the
    # conversational check below. This fails if `train_dataset` is None.
    dataset_sample = next(iter(train_dataset))
    if args.completion_only_loss is None:
        # Default: completion-only loss is on exactly when the sample has a "prompt" key.
        self.completion_only_loss = "prompt" in dataset_sample
    else:
        self.completion_only_loss = args.completion_only_loss

    if data_collator is None:
        # Get the pad token: if not provided, use the one from the processing class or the eos token
        # if the processing class does not have a pad token.
        pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
        pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
        if pad_token_id is None:
            raise ValueError(
                f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
                "in the vocabulary before using it as a padding token."
            )
        data_collator = DataCollatorForLanguageModeling(
            pad_token_id=pad_token_id,
            completion_only_loss=self.completion_only_loss,
            padding_free=self.padding_free,
            # Using position_ids without flash_attn hurts the training
            return_position_ids=use_flash_attention,
            pad_to_multiple_of=args.pad_to_multiple_of,
        )

    if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
        warnings.warn(
            "You are using packing, but the attention implementation is not set to 'flash_attention_2' or "
            "'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash "
            "Attention is the only known attention mechanisms that reliably support this. Using other "
            "implementations may lead to cross-contamination between batches. To avoid this, either disable "
            "packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or "
            "`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration."
        )
    if args.assistant_only_loss and not is_conversational(dataset_sample):
        raise ValueError(
            "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only "
            "supported for conversational datasets."
        )

    # Dataset
    # Preparation runs unless `dataset_kwargs` carries a truthy "skip_prepare_dataset".
    preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
    if preprocess_dataset:
        if self.completion_only_loss and formatting_func:
            raise ValueError(
                "A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
                "Using a formatter converts the dataset to a language modeling type, conflicting with "
                "completion-only loss. To resolve this, apply your formatting function before passing the "
                "dataset, or disable `completion_only_loss` in `SFTConfig`."
            )
        train_dataset = self._prepare_dataset(
            train_dataset, processing_class, args, args.packing, formatting_func, "train"
        )
        if eval_dataset is not None:
            # `eval_packing` overrides `packing` for evaluation only when it is not None.
            packing = args.packing if args.eval_packing is None else args.eval_packing
            if isinstance(eval_dataset, dict):
                eval_dataset = {
                    key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                    for key, dataset in eval_dataset.items()
                }
            else:
                eval_dataset = self._prepare_dataset(
                    eval_dataset, processing_class, args, packing, formatting_func, "eval"
                )

    # Initialize the metrics
    self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
    self._total_train_tokens = 0

    # Initialize the Trainer. Parent class will handle:
    # - DeepSpeed configuration [through create_accelerator_and_postprocess]
    # - FSDP setup
    # - Distributed training setup
    # - Optimizer and scheduler creation

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_loss_func=compute_loss_func,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Initialize activation offloading context
    if self.args.activation_offloading:
        self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
    else:
        self.maybe_activation_offload_context = contextlib.nullcontext()

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)
680
+
681
def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
    """Load a causal language model from a hub id or a local path.

    `args.model_init_kwargs` (or an empty dict) is forwarded to
    `AutoModelForCausalLM.from_pretrained`. A `torch_dtype` entry given as a
    string other than "auto" is replaced, in that same dict, by the attribute
    of that name on the `torch` module.

    Raises:
        ValueError: If `torch_dtype` is neither `None`, "auto", a
            `torch.dtype`, nor a string.
    """
    init_kwargs = args.model_init_kwargs or {}

    # Normalise the dtype entry: None, "auto" and real dtypes pass through as they are.
    dtype = init_kwargs.get("torch_dtype")
    passes_through = dtype is None or isinstance(dtype, torch.dtype) or dtype == "auto"
    if not passes_through:
        if not isinstance(dtype, str):
            raise ValueError(
                "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
            )
        # A dtype name such as "float32" is resolved on the torch module.
        init_kwargs["torch_dtype"] = getattr(torch, dtype)

    # Note: no `use_cache` override is applied here, even with gradient checkpointing on.
    return AutoModelForCausalLM.from_pretrained(model_path, **init_kwargs)
703
+
704
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
    """Wrap `model` for PEFT training and return it.

    Quantised models (4-bit or 8-bit flags set) go through
    `_prepare_model_for_kbit_training`, unless the first `Params4bit`
    parameter lives on the "cpu" or "meta" device, which this method treats
    as a sharded QLoRA setup. Otherwise gradient checkpointing is enabled
    when `args.gradient_checkpointing` is set. A non-None `peft_config` is
    then applied with `get_peft_model`.

    Raises:
        ImportError: If the `peft` library is not available.
    """
    if not is_peft_available():
        raise ImportError("To use PeftModel, you need to install the `peft` library.")

    # Either quantisation flag marks the model as a QLoRA candidate.
    quantized = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)

    # Only the first Params4bit parameter is inspected to decide on sharding.
    sharded_4bit = False
    if getattr(model, "is_loaded_in_4bit", False):
        first_4bit_param = next(
            (p for _, p in model.named_parameters() if p.__class__.__name__ == "Params4bit"),
            None,
        )
        if first_4bit_param is not None:
            sharded_4bit = first_4bit_param.data.device.type in {"cpu", "meta"}

    if quantized and not sharded_4bit:
        model = self._prepare_model_for_kbit_training(model, args)
        # The k-bit preparation already dealt with checkpointing; work on a copy with it off.
        args = dataclasses.replace(args, gradient_checkpointing=False)
    elif args.gradient_checkpointing:
        model = self._enable_gradient_checkpointing(model, args)

    if peft_config is not None:
        # `autocast_adapter_dtype` exists from peft 0.12 on; it is switched off for sharded 4-bit models.
        disable_autocast = (
            version.parse(peft.__version__) >= version.parse("0.12")
            and getattr(model, "is_loaded_in_4bit", False)
            and sharded_4bit
        )
        extra_kwargs = {"autocast_adapter_dtype": False} if disable_autocast else {}
        model = get_peft_model(model, peft_config, **extra_kwargs)

    # bf16 casting applies to non-sharded 4-bit models only.
    if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not sharded_4bit:
        peft_module_casting_to_bf16(model)

    return model
744
+
745
def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
    """Run `prepare_model_for_kbit_training` on a quantised model.

    The gradient checkpointing flag and its kwargs (an empty dict when unset)
    are taken from `args`; the helper's return value is returned as is.
    """
    return prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
    )
753
+
754
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
755
+ """Enables gradient checkpointing for the model."""
756
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
757
+ use_reentrant = (
758
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
759
+ )
760
+
761
+ if use_reentrant:
762
+ if hasattr(model, "enable_input_require_grads"):
763
+ model.enable_input_require_grads()
764
+ else:
765
+
766
+ def make_inputs_require_grad(module, input, output):
767
+ output.requires_grad_(True)
768
+
769
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
770
+
771
+ return model
772
+
773
def _prepare_dataset(
    self,
    dataset: Union[Dataset, IterableDataset],
    processing_class,
    args,
    packing: bool,
    formatting_func: Optional[Callable[[dict], str]],
    dataset_name: str,
) -> Union[Dataset, IterableDataset]:
    """Tokenize and optionally pack `dataset` for SFT training.

    Behaviour depends on the columns of the first sample:

    - "labels" present: the dataset is left untokenized and
      `self.data_collator` is set to a `DataCollatorForSeq2Seq`.
    - "input_ids" present: the dataset is left untokenized and
      `self.data_collator` is set to a `DataCollatorForLanguageModeling`.
    - otherwise the text column (`args.dataset_text_field`, default "text")
      or the output of `formatting_func` is tokenized with `dataset.map`.

    When `packing` is true and `pack_dataset` is available, the selected
    columns are packed to `max_seq_length`.

    Args:
        dataset: Dataset to prepare.
        processing_class: Tokenizer, or a processor exposing `.tokenizer`.
        args: Config object; read through `getattr` with fallbacks.
        packing: Whether to pack the tokenized dataset.
        formatting_func: Batched formatter returning a list of strings.
            Required when the text column is missing.
        dataset_name: Used only in the packing progress description.

    Returns:
        The prepared dataset (the input itself for a `ConstantLengthDataset`).

    Raises:
        RuntimeError: If no maximum length can be resolved, if a processor's
            tokenizer lacks `.pad`, or if `formatting_func` is needed but
            missing.
        ValueError: If `formatting_func` does not return a list.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    try:
        if isinstance(dataset, ConstantLengthDataset): return dataset
    except (NameError, TypeError):
        # `ConstantLengthDataset` may be undefined (or not a class) in this environment.
        pass

    map_kwargs = {}
    use_desc = isinstance(dataset, Dataset)
    is_vlm = hasattr(processing_class, "tokenizer")
    tokenizer = processing_class
    if is_vlm: tokenizer = processing_class.tokenizer

    # Get max length: first non-zero value among several possible attribute names
    max_seq_length = getattr(args, "max_length", 0)
    if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
    if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
    if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
    if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
    dataset_text_field = getattr(args, "dataset_text_field", "text")
    do_truncation = max_seq_length != 0
    do_formatting_func = False
    do_tokenize = True

    # Get correct column names
    column_names = set(next(iter(dataset)).keys())
    used_column_names = ["input_ids"]
    if "attention_mask" in column_names:
        used_column_names.append("attention_mask")

    # Check if already tokenized so skip
    from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
    if "labels" in column_names:
        # Most likely forgot data collator!
        if is_vlm and not hasattr(tokenizer, "pad"):
            # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
            raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
        self.data_collator = DataCollatorForSeq2Seq(tokenizer)
        used_column_names.append("labels")
        do_tokenize = False
    elif "input_ids" in column_names:
        # Skip dataset prep, and set data collator
        if is_vlm and not hasattr(tokenizer, "pad"):
            # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
            raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
        self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
        do_tokenize = False
    elif dataset_text_field not in column_names:
        do_formatting_func = True
        if formatting_func is None:
            raise RuntimeError("Unsloth: You must specify a `formatting_func`")

    if do_tokenize:
        # Check double BOS tokens
        if do_formatting_func:
            test_text = formatting_func(next(iter(dataset)))
            if not isinstance(test_text, list):
                raise ValueError(
                    "Unsloth: The `formatting_func` should return a list of processed strings."
                )
            test_text = test_text[0]
        else:
            # A single row holds the text as a plain string. Indexing it with [0] would
            # yield only its first character, so a multi-character BOS token could never
            # match (and an empty text would raise IndexError). Index only real sequences.
            test_text = next(iter(dataset))[dataset_text_field]
            if not isinstance(test_text, str):
                test_text = test_text[0]

        # Get chat template
        chat_template = getattr(processing_class, 'chat_template', '')
        if chat_template == '' and is_vlm:
            chat_template = getattr(tokenizer, 'chat_template', '')
        if chat_template is None:
            chat_template = ''

        # Get bos_token
        add_special_tokens = True
        bos_token_1 = getattr(processing_class, 'bos_token', None)
        bos_token_2 = getattr(tokenizer, 'bos_token', None)
        bos_token = bos_token_1 or bos_token_2

        if bos_token is not None:
            # The text (or the chat template) already supplies BOS, so the tokenizer must not add another.
            if test_text.startswith(bos_token) or bos_token in chat_template:
                add_special_tokens = False
                print("Unsloth: We found double BOS tokens - we shall remove one automatically.")

        # Create tokenize function (called with batches, since map uses batched = True)
        def _tokenize(example):
            return tokenizer(
                example[dataset_text_field] if not do_formatting_func else formatting_func(example),
                truncation = do_truncation,
                max_length = max_seq_length,
                return_token_type_ids = False,
                add_special_tokens = add_special_tokens,
            )

        if not isinstance(dataset, IterableDataset):
            map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
        else:
            map_kwargs["batch_size"] = dataset._ex_iterable.batch_size

        if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
        dataset = dataset.map(_tokenize, batched = True, **map_kwargs)

        # If VLM, switch data collator since .pad is needed!
        if is_vlm and not hasattr(processing_class, "pad"):
            data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
            self.data_collator = data_collator

    if packing:
        # Try using new packing which works in TRL
        try:
            pack_dataset
        except NameError:
            # `pack_dataset` is not defined here: return the dataset unpacked.
            print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
            return dataset

        if max_seq_length == 0:
            raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")

        if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
        dataset = pack_dataset(
            dataset.select_columns(used_column_names),
            max_seq_length,
            getattr(args, "packing_strategy", "bfd"),
            map_kwargs,
        )
    return dataset
910
+
911
+ def _set_signature_columns_if_needed(self):
912
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
913
+ # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
914
+ # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
915
+ # dataset. So we need to override the default signature columns to include "completion_mask" as well.
916
+ if self._signature_columns is None:
917
+ self._signature_columns = [
918
+ "input_ids",
919
+ "labels",
920
+ "seq_lengths",
921
+ "completion_mask",
922
+ "assistant_masks",
923
+ ]
924
+
925
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
    """Delegate loss computation to the parent class and return its result unchanged."""
    return super().compute_loss(
        model,
        inputs,
        return_outputs = return_outputs,
        num_items_in_batch = num_items_in_batch,
    )
933
+
934
+ # Override training step to add activation offloading context.
935
def training_step(self, *args, **kwargs):
    """Run the parent training step inside `self.maybe_activation_offload_context`."""
    with self.maybe_activation_offload_context:
        step_result = super().training_step(*args, **kwargs)
        return step_result
938
+
939
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """Merge the averaged buffered metrics into `logs` and pass them to the parent logger.

    The buffer is chosen by `self.model.training` ("train" or "eval"). Each
    buffered list is reduced to its mean; in eval mode the keys get an
    "eval_" prefix to match the keys already present in `logs`. Buffered
    metrics override equal keys in `logs`. The buffer is cleared afterwards.
    """
    in_eval = not self.model.training
    mode = "eval" if in_eval else "train"

    # Average every buffered metric, renaming the key for evaluation.
    averaged = {}
    for name, values in self._metrics[mode].items():
        out_key = f"eval_{name}" if in_eval else name
        averaged[out_key] = sum(values) / len(values)

    merged = dict(logs)
    merged.update(averaged)
    super().log(merged, start_time)
    self._metrics[mode].clear()
951
+
952
+ # Ensure the model card is saved along with the checkpoint
953
def _save_checkpoint(self, model, trial):
    """Write the model card, then let the parent class save the checkpoint.

    The card's model name is the last path component of `args.output_dir`
    when `args.hub_model_id` is None, otherwise the part of the hub id after
    the final "/".
    """
    hub_id = self.args.hub_model_id
    model_name = Path(self.args.output_dir).name if hub_id is None else hub_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
960
+
961
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Generate a draft model card and save it as `README.md` in `args.output_dir`.

    Nothing is written unless `is_world_process_zero()` is true.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The trainer's own tags are always added, and "unsloth" is
            added when the model config has an `unsloth_version` attribute.
    """
    if not self.is_world_process_zero():
        return

    config = self.model.config

    # The base model is recorded only when the config's name is not a local directory.
    base_model = None
    if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
        base_model = config._name_or_path

    # Collect the tags into a fresh set, whatever form they were passed in.
    if tags is None:
        tag_set = set()
    elif isinstance(tags, str):
        tag_set = {tags}
    else:
        tag_set = set(tags)

    if hasattr(config, "unsloth_version"):
        tag_set.add("unsloth")
    tag_set.update(self._tag_names)

    # A W&B link is included only when wandb is available and a run is active.
    wandb_url = None
    if is_wandb_available() and wandb.run is not None:
        wandb_url = wandb.run.url

    card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=list(tag_set),
        wandb_url=wandb_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="SFT",
    )

    readme_path = os.path.join(self.args.output_dir, "README.md")
    card.save(readme_path)
1011
class UnslothSFTTrainer(_UnslothSFTTrainer):
    """

    Trainer for Supervised Fine-Tuning (SFT) method.

    This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.

    Before delegating to the parent constructor, `__init__` adjusts `args` in place (mixed-precision flags,
    evaluation batch settings, `max_seq_length` / `max_length`), forces right padding on the processing class,
    may swap the data collator, and sets several environment variables (`ACCELERATE_MIXED_PRECISION`,
    `UNSLOTH_RETURN_LOGITS`).

    Example:

    ```python
    from datasets import load_dataset
    from trl import SFTTrainer

    dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

    trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.

            NOTE(review): this wrapper's `__init__` reads `model.config` and calls `model.get_input_embeddings()`
            before delegating, so the string form looks unsupported here — confirm.
        args ([`SFTConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, an `UnslothSFTConfig()` is created.
        data_collator (`DataCollator`, *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to a custom [`DataCollatorForLanguageModeling`].
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
            [prompt-completion](#prompt-completion) type. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).

            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoTokenizer.from_pretrained`].
        compute_loss_func (*optional*, defaults to `None`):
            Forwarded unchanged to the parent constructor.
        compute_metrics (*optional*, defaults to `None`):
            Forwarded unchanged to the parent constructor. When it is not `None`, the environment variable
            `UNSLOTH_RETURN_LOGITS` is set to `'1'`.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
            Not an explicit parameter of this wrapper; if given it travels through `**kwargs` to the parent.
        optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
            `args`. Incompatible with the `optimizers` argument.

            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
            initializing the Trainer.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
            When it is not `None`, the environment variable `UNSLOTH_RETURN_LOGITS` is set to `'1'`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
        formatting_func (`Optional[Callable]`):
            Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly
            converts the dataset into a [language modeling](#language-modeling) type.
        **kwargs:
            Any further keyword arguments, forwarded unchanged to the parent constructor.

    Raises:
        TypeError: If the model dtype resolves to float16 while `args.bf16` is `True`, or resolves to anything else
            while `args.fp16` is `True` (skipped when `UNSLOTH_FORCE_FLOAT32` is `'1'`).

    """
    def __init__(
        self,
        model,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_loss_func = None,
        compute_metrics = None,
        callbacks = None,
        optimizer_cls_and_kwargs = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        # NOTE: several checks below inspect `locals()`, so the names and ordering of local
        # variables in this function are behaviourally significant. Do not rename or reorder casually.
        if args is None: args = UnslothSFTConfig()

        # --- Mixed-precision reconciliation -------------------------------------------------
        # Only genuine `bool` values are honoured for bf16/fp16; anything else is treated as False.
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Model dtype: config value first, falling back to the input-embedding dtype.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject a requested half-precision mode that disagrees with the model dtype.
        # Every non-float16 dtype takes the second branch, whose message says "bfloat16".
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: derive both from the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # --- Evaluation defaults ------------------------------------------------------------
        # NOTE(review): this reads `eval_dataset` off `args`, not the `eval_dataset` parameter of
        # this constructor — confirm that is intended, otherwise this branch may never run.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Advisory only: prints an upgrade hint for transformers <= 4.45.2, does not abort.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        # Make the full-eval precision agree with the training precision chosen above.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # --- Logits are requested whenever metrics need them --------------------------------
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # --- Sequence-length reconciliation between `model` and `args` ----------------------
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            # args has no value but the model does: copy the model's value into args.
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if 'max_length' not in locals() and not hasattr(args, 'max_length'):
            pass
        else:
            if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
                # A positive args.max_seq_length overrides args.max_length.
                if hasattr(args, 'max_length'):
                    args.max_length = args.max_seq_length
                    max_length = args.max_length
            else:
                # Otherwise fall back to the model's max_seq_length, then its max_length.
                model_max_length = getattr(model, 'max_seq_length', None)
                if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
                if model_max_length is not None:
                    args.max_length = model_max_length
                    max_length = args.max_length
                elif hasattr(args, 'max_length') and args.max_length is not None:
                    max_length = args.max_length
                    # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set
                    setattr(model, 'max_seq_length', max_length)
                else:
                    print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
                    args.max_length = 1024

        # --- Model / tokenizer preparation --------------------------------------------------
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # `tokenizer` is only assigned further down, so at this point it is normally not in locals().
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            # Right padding is forced on the processing class and on any tokenizer it wraps.
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer

        # --- Data-collator selection --------------------------------------------------------
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator kind so it matches whether the dataset has a `labels` column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator: keep all columns and skip the dataset preparation step.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` but wraps a tokenizer: rebuild the collator on that tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('sft_trainer', other_metrics)
        # Newline-separated list read from the environment; an unset variable yields [''].
        IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
        from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
        from unsloth_zoo.training_utils import fix_zero_training_loss
        if 'tokenizer' not in locals(): tokenizer = processing_class
        fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
        fix_zero_training_loss(model, tokenizer, train_dataset)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_loss_func = compute_loss_func,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,**kwargs)

        # --- Post-init fix-ups --------------------------------------------------------------
        # Remove any NEFTune hook the parent installed and store the alpha on the embedding module instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            # Attach the accelerator's scaler to every level of the nested `.model` chain.
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
1251
+ pass
1252
+
1253
+ pass
unsloth_compiled_cache/UnslothXPOTrainer.py ADDED
@@ -0,0 +1,1062 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.8.8
3
+ 2025.8.9
4
+ 4.55.2
5
+ 0.21.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
+ from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
14
+
15
+
16
+ import os
17
+ from typing import *
18
+ from dataclasses import dataclass, field
19
+ from packaging.version import Version
20
+ import torch
21
+ import numpy as np
22
+ from contextlib import nullcontext
23
+ from torch.nn import functional as F
24
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
+
26
# Options passed as the `options` argument of the `torch.compile` decorator applied to
# `chunked_selective_log_softmax` below.
# NOTE(review): the keys look like TorchInductor config names — confirm their exact meaning
# against the PyTorch documentation before changing any value.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
33
+
34
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` over its last dimension, gathered at `index`.

    The rows are flattened and processed in at most four chunks along dim 0; each chunk is upcast
    to float32 before the gather and the logsumexp, so only one chunk is held in float32 at a time.

    Args:
        logits: Tensor whose last dimension is the one normalised over.
            Assumes a 3-D layout, since the result is reshaped to its first two dims — TODO confirm.
        index: Tensor of positions to select along the last dimension of `logits`; flattened to 1-D.
            Presumably an integer tensor with the same leading shape as `logits` — verify against callers.

    Returns:
        float32 tensor of shape `(logits.shape[0], logits.shape[1])` holding
        `logits[..., index] - logsumexp(logits, dim=-1)` for each position.
    """
    dim0, dim1, last_dim = logits.shape[0], logits.shape[1], logits.shape[-1]
    row_chunks = torch.chunk(logits.reshape(-1, last_dim), chunks = 4, dim = 0)
    idx_chunks = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)

    pieces = []
    for rows, ids in zip(row_chunks, idx_chunks):
        rows_f32 = rows.to(torch.float32)
        picked = torch.gather(rows_f32, dim = -1, index = ids.unsqueeze(-1)).squeeze(-1)
        # log_softmax(x)[i] == x[i] - logsumexp(x)
        pieces.append(picked - torch.logsumexp(rows_f32, dim = -1))

    return torch.concat(pieces).reshape((dim0, dim1))
51
@dataclass
class UnslothXPOConfig(XPOConfig):
    """

    Configuration class for the [`XPOTrainer`].

    Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:

    Parameters:
        alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
            Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
            and the last alpha is used for the rest of the epochs.
            Not an explicit parameter of this class's `__init__`; presumably accepted by `XPOConfig` via `**kwargs`.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to the parent constructor.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance only.
        max_seq_length (`int`, *optional*, defaults to `None`):
            Maximum sequence length to truncate to. Stored on the instance only.

    `__init__` applies these checks and default adjustments before calling the parent constructor:

    - `learning_rate` below `1e-7` raises `FloatingPointError`; above `1` raises `OverflowError`.
    - When `output_dir` is `None` and `save_strategy` / `save_steps` are still `'steps'` / `500`,
      `output_dir` becomes `'unsloth_training_checkpoints'` and `save_strategy` becomes `'no'`.
    - When `dataset_num_proc` is `None` it is set to `min(cpu_count()*2, 2)`.
    - `temperature` must lie strictly between 0 and 10, otherwise `ValueError` is raised.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = True,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = True,
        reward_model_path = None,
        judge = None,
        max_new_tokens = 64,
        max_length = 512,
        temperature = 0.9,
        missing_eos_penalty = None,
        loss_type = 'sigmoid',
        dataset_num_proc = None,
        disable_dropout = True,
        use_vllm = False,
        vllm_model_impl = 'vllm',
        gpu_memory_utilization = 0.55,
        ds3_gather_for_generation = True,
        model_init_kwargs = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        max_seq_length = None,
        **kwargs,
    ):
        """
        Validate a few hyper-parameters, fill in defaults, then initialise the parent config.

        Every named parameter up to `model_init_kwargs`, plus `**kwargs`, is forwarded unchanged to
        `XPOConfig.__init__`, apart from the adjustments to `output_dir`, `save_strategy` and
        `dataset_num_proc` below. The last three named parameters are only stored on `self`.

        Raises:
            FloatingPointError: If `learning_rate` is below `1e-7`.
            OverflowError: If `learning_rate` is above `1`.
            ValueError: If `temperature` is `<= 0` or `>= 10`.
        """
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # With save settings left at their defaults and no output dir, use a fixed directory
        # name and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            # The `min(..., 2)` caps the default at two worker processes.
            dataset_num_proc = min(cpu_count()*2, 2)
        # `MathError` was never defined or imported, so the original raised NameError here;
        # ValueError is the built-in exception for an out-of-range argument.
        if temperature <= 0:
            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
        elif temperature >= 10:
            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')


        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            reward_model_path = reward_model_path,
            judge = judge,
            max_new_tokens = max_new_tokens,
            max_length = max_length,
            temperature = temperature,
            missing_eos_penalty = missing_eos_penalty,
            loss_type = loss_type,
            dataset_num_proc = dataset_num_proc,
            disable_dropout = disable_dropout,
            use_vllm = use_vllm,
            vllm_model_impl = vllm_model_impl,
            gpu_memory_utilization = gpu_memory_utilization,
            ds3_gather_for_generation = ds3_gather_for_generation,
            model_init_kwargs = model_init_kwargs,**kwargs)
        # Unsloth-specific extras live on the instance only; they are not passed to the parent.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
386
+ pass
387
+
388
+ class _UnslothXPOTrainer(OnlineDPOTrainer):
389
+ r""""""
390
+
391
+ _tag_names = ["trl", "xpo"]
392
+
393
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    ref_model: Union[PreTrainedModel, nn.Module] = None,
    reward_model: Optional[nn.Module] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[XPOConfig] = None,
    data_collator: Optional[Callable] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
    """Forward all arguments to the parent trainer, then set up XPO-specific state.

    After the parent constructor returns, `args.alpha` is cached on `self._alpha`
    (read back through the `alpha` property) and `self.stats` is replaced by a
    dictionary holding one empty list per XPO metric.
    """
    parent_kwargs = dict(
        model=model,
        ref_model=ref_model,
        judge=judge,
        reward_model=reward_model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        # for now, XPOTrainer can't use any reward model, so the policy's
        # processing class is passed as the reward processing class as well
        reward_processing_class=processing_class,
        peft_config=peft_config,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    super().__init__(**parent_kwargs)

    self._alpha = self.args.alpha

    # Overwrite the stats dictionary to include XPO specific statistics:
    #   - "non_score_reward", "rlhf_reward" and "scores" are dropped,
    #   - "loss/dpo" and "loss/xpo" are added,
    #   - "contain_eos_token" is split into a model and a ref variant.
    metric_names = (
        "loss/dpo",
        "loss/xpo",
        "objective/kl",
        "objective/entropy",
        "rewards/chosen",
        "rewards/rejected",
        "rewards/accuracies",
        "rewards/margins",
        "logps/chosen",
        "logps/rejected",
        "val/model_contain_eos_token",
        "val/ref_contain_eos_token",
        "alpha",
        "beta",
    )
    self.stats = {metric: [] for metric in metric_names}

    if self.reward_model is not None:
        # "scores" is replaced by per-policy scores plus their margin
        for metric in ("objective/model_scores", "objective/ref_scores", "objective/scores_margin"):
            self.stats[metric] = []
457
+
458
@property
def alpha(self):
    """Return the XPO coefficient that applies to the current epoch.

    `self._alpha` is either a single value, returned unchanged, or a list
    holding one value per epoch. For a list, the entry for the current epoch is
    returned; once training runs past the end of the list the last entry is
    reused.
    """
    schedule = self._alpha
    if not isinstance(schedule, list):
        return schedule

    # `self.state.epoch` is a fractional float during training (and may still be
    # None before the first step); list indices must be integers, so the epoch
    # is floored before it is used as an index.
    epoch = self.state.epoch
    epoch_index = 0 if epoch is None else int(epoch)
    return schedule[epoch_index] if epoch_index < len(schedule) else schedule[-1]
465
+
466
def _generate_completions(self, prompts, model):
    """Generate one completion per prompt from the policy and from the reference.

    Args:
        prompts: mapping with "input_ids" and "attention_mask" entries that are
            passed straight to `generate`.
        model: the policy model being trained.

    Returns:
        `(model_output, ref_output)`: the raw `generate` results of the policy
        and of the reference model, in that order.
    """
    generate_kwargs = {
        "input_ids": prompts["input_ids"],
        "attention_mask": prompts["attention_mask"],
        "generation_config": self.generation_config,
    }

    with unwrap_model_for_generation(model, self.accelerator) as policy:
        model_output = policy.generate(**generate_kwargs)

    # Pick the module that plays the reference role.
    reference_source: torch.nn.Module
    if self.ref_model is not None:
        reference_source = self.accelerator.unwrap_model(self.ref_model)
    else:
        # No explicit reference: fall back to the policy itself, and to its base
        # model (adapters stripped) when the policy is a PEFT model.
        reference_source = self.accelerator.unwrap_model(model)
        if is_peft_available() and isinstance(reference_source, PeftModel):
            reference_source = reference_source.get_base_model()

    with unwrap_model_for_generation(reference_source, self.accelerator) as reference:
        ref_output = reference.generate(**generate_kwargs)

    return model_output, ref_output
493
+
494
def _process_completions(self, model_output, ref_output, prompts):
    """Turn raw generations into prompt+completion batches for both policies.

    For each of `model_output` and `ref_output`, the prompt columns are cut off,
    the remainder is passed through `truncate_right` together with the EOS and
    pad token ids, and the result is concatenated back onto the prompt.

    Returns:
        `(model_data, ref_data)`: two dicts with "input_ids", "attention_mask"
        and the untouched "raw" prompts.
    """
    prompt_ids = prompts["input_ids"]
    prompt_mask = prompts["attention_mask"]
    context_length = prompt_ids.shape[1]

    def assemble(generated):
        # Keep only the generated part, then let truncate_right produce the
        # cleaned completion ids and the matching attention mask.
        completion_ids, completion_mask = truncate_right(
            generated[:, context_length:],
            self.processing_class.eos_token_id,
            self.processing_class.pad_token_id,
        )
        return {
            "input_ids": torch.cat((prompt_ids, completion_ids), dim=1),
            "attention_mask": torch.cat((prompt_mask, completion_mask), dim=1),
            "raw": prompts["raw"],
        }

    model_data = assemble(model_output)
    ref_data = assemble(ref_output)
    return model_data, ref_data
520
+
521
def _compute_rewards(self, model_data, ref_data, context_length):
    """Score the policy and reference completions with the reward model.

    Args:
        model_data: dict whose "input_ids" hold prompt + policy completion.
        ref_data: dict whose "input_ids" hold prompt + reference completion.
        context_length: number of prompt columns at the start of "input_ids".

    Returns:
        `(model_scores, ref_scores)`. When `args.missing_eos_penalty` is set,
        the penalty has been subtracted from every row whose *completion*
        contains no EOS token.
    """
    pad_token_id = self.processing_class.pad_token_id
    with torch.no_grad():
        _, model_scores, _ = get_reward(self.reward_model, model_data["input_ids"], pad_token_id, context_length)
        _, ref_scores, _ = get_reward(self.reward_model, ref_data["input_ids"], pad_token_id, context_length)

    # Apply EOS penalty if needed
    penalty = self.args.missing_eos_penalty
    if penalty is not None:
        eos_token_id = self.processing_class.eos_token_id
        # Only the completion is inspected (same slice `_log_statistics` uses).
        # Scanning the prompt as well would let an EOS-valued pad token in the
        # prompt mask a completion that never terminated.
        model_contain_eos = torch.any(model_data["input_ids"][:, context_length:] == eos_token_id, dim=-1)
        ref_contain_eos = torch.any(ref_data["input_ids"][:, context_length:] == eos_token_id, dim=-1)
        model_scores[~model_contain_eos] -= penalty
        ref_scores[~ref_contain_eos] -= penalty

    return model_scores, ref_scores
538
+
539
def _compute_judge(self, model_data, ref_data, context_length):
    """Ask the pairwise judge which completion wins for every prompt.

    Completions are decoded (special tokens skipped) and stripped. When the
    first raw prompt is conversational, prompts and completions are rendered to
    plain text with `SIMPLE_CHAT_TEMPLATE` before being judged.

    Returns:
        A boolean tensor, on the device of `model_data["input_ids"]`, that is
        True where the judge ranked the policy completion first (rank 0).
    """

    def decode_completions(data):
        decoded = self.processing_class.batch_decode(
            data["input_ids"][:, context_length:], skip_special_tokens=True
        )
        return [text.strip() for text in decoded]

    prompts = model_data["raw"]
    model_texts = decode_completions(model_data)
    ref_texts = decode_completions(ref_data)

    if is_conversational({"prompt": prompts[0]}):
        template = jinja2.Environment().from_string(SIMPLE_CHAT_TEMPLATE)

        def render_as_assistant(text):
            # Wrap the completion in a single assistant turn, then render it.
            return template.render(messages=[{"role": "assistant", "content": text}])

        prompts = [template.render(messages=message) for message in prompts]
        model_texts = [render_as_assistant(text) for text in model_texts]
        ref_texts = [render_as_assistant(text) for text in ref_texts]

    candidate_pairs = list(zip(model_texts, ref_texts))
    ranks_of_first_completion = self.judge.judge(prompts, candidate_pairs)

    # convert ranks to a True/False mask:
    # rank == 0 -> the first (policy) completion is the best
    # rank == 1 -> the second (reference) completion is the best
    policy_wins = [rank == 0 for rank in ranks_of_first_completion]
    return torch.tensor(policy_wins, device=model_data["input_ids"].device)
573
+
574
def _compute_logprobs(self, model, model_data, ref_data, context_length):
    """Compute per-token completion log-probs under the policy and the reference.

    The policy passes keep gradients; the reference passes run under
    `torch.no_grad()`. Without a separate `self.ref_model`, the reference passes
    reuse `model` inside `model.disable_adapter()`.

    Returns:
        A 4-tuple, in this order: policy log-probs on policy completions, policy
        log-probs on reference completions, reference log-probs on reference
        completions, reference log-probs on policy completions. Positions whose
        attention mask is 0 are set to 0.0.
    """

    def completion_logprobs(network, batch):
        outputs = network(batch["input_ids"], attention_mask=batch["attention_mask"])
        # logits at position t predict token t + 1, hence the one-step shift
        shifted_logits = outputs.logits[:, context_length - 1 : -1]
        return selective_log_softmax(shifted_logits, batch["input_ids"][:, context_length:])

    # Policy on its own completions, then on the reference completions (XPO term)
    policy_on_model = completion_logprobs(model, model_data)
    policy_on_ref = completion_logprobs(model, ref_data)

    # Reference model on both sets of completions
    with torch.no_grad():
        if self.ref_model is not None:
            reference_on_model = completion_logprobs(self.ref_model, model_data)
            reference_on_ref = completion_logprobs(self.ref_model, ref_data)
        else:
            with model.disable_adapter():
                reference_on_model = completion_logprobs(model, model_data)
                reference_on_ref = completion_logprobs(model, ref_data)

    # Mask padding tokens
    model_pad = model_data["attention_mask"][:, context_length:] == 0
    ref_pad = ref_data["attention_mask"][:, context_length:] == 0

    return (
        policy_on_model.masked_fill(model_pad, 0.0),
        policy_on_ref.masked_fill(ref_pad, 0.0),
        reference_on_ref.masked_fill(ref_pad, 0.0),
        reference_on_model.masked_fill(model_pad, 0.0),
    )
605
+
606
+ def _compute_losses(
607
+ self,
608
+ model_logprobs_model_data,
609
+ model_logprobs_ref_data,
610
+ ref_logprobs_ref_data,
611
+ ref_logprobs_model_data,
612
+ chosen_mask,
613
+ ):
614
+ # Compute log probs
615
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
616
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
617
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
618
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
619
+
620
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
621
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
622
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
623
+
624
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
625
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
626
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
627
+
628
+ # Compute logits as the difference between chosen and rejected log ratios
629
+ logits = chosen_log_ratios - rejected_log_ratios
630
+
631
+ if self.args.loss_type == "sigmoid":
632
+ dpo_losses = -F.logsigmoid(self.beta * logits)
633
+ elif self.args.loss_type == "ipo":
634
+ dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
635
+ else:
636
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
637
+
638
+ # Compute XPO specific loss
639
+ xpo_losses = self.alpha * model_logprobs_ref_data_sum
640
+
641
+ # Total loss
642
+ loss = (dpo_losses + xpo_losses).mean()
643
+
644
+ return loss, dpo_losses, xpo_losses
645
+
646
+ def _log_statistics(
647
+ self,
648
+ model_data,
649
+ ref_data,
650
+ model_logprobs_model_data,
651
+ model_logprobs_ref_data,
652
+ ref_logprobs_ref_data,
653
+ ref_logprobs_model_data,
654
+ chosen_mask,
655
+ dpo_losses,
656
+ xpo_losses,
657
+ context_length,
658
+ model_scores=None,
659
+ ref_scores=None,
660
+ ):
661
+ # Helper function to gather and compute mean
662
+ def gather_mean(tensor):
663
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
664
+
665
+ # Log losses
666
+ self.stats["loss/dpo"].append(gather_mean(dpo_losses))
667
+ self.stats["loss/xpo"].append(gather_mean(xpo_losses))
668
+
669
+ # Log scores
670
+ if self.reward_model is not None:
671
+ self.stats["objective/model_scores"].append(gather_mean(model_scores))
672
+ self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
673
+ self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
674
+
675
+ # Log logprobs
676
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
677
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
678
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
679
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
680
+
681
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
682
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
683
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
684
+
685
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
686
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
687
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
688
+
689
+ self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
690
+ self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
691
+
692
+ # Log rewards
693
+ # Compute various statistics
694
+ chosen_rewards = chosen_log_ratios * self.beta
695
+ rejected_rewards = rejected_log_ratios * self.beta
696
+ self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
697
+ self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
698
+
699
+ # Calculate KL divergence for model and ref data
700
+ kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
701
+ kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
702
+ mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
703
+ self.stats["objective/kl"].append(gather_mean(mean_kl))
704
+
705
+ # Calculate entropy for model and ref data
706
+ entropy_model_data = -model_logprobs_model_data.sum(1)
707
+ entropy_ref_data = -model_logprobs_ref_data.sum(1)
708
+ mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
709
+ self.stats["objective/entropy"].append(gather_mean(mean_entropy))
710
+
711
+ # Calculate margins
712
+ margin = chosen_rewards - rejected_rewards
713
+ self.stats["rewards/margins"].append(gather_mean(margin.mean()))
714
+
715
+ # Calculate accuracy
716
+ accuracy = (margin > 0).float()
717
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
718
+
719
+ # Log EOS token statistics
720
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
721
+ ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
722
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
723
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
724
+
725
+ # Log alpha and beta
726
+ self.stats["alpha"].append(self.alpha)
727
+ self.stats["beta"].append(self.beta)
728
+
729
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one XPO step: sample, rank, compute the loss and back-propagate.

    The prompts in `inputs` are chat-templated, tokenized and collated; both the
    policy and the reference generate a completion for each prompt; the pair is
    ranked by the reward model (or by the judge when there is none); the loss is
    computed, statistics are logged and the backward pass is run.

    Returns:
        The detached loss divided by `args.gradient_accumulation_steps`.
    """
    model.train()

    # Apply chat template and tokenize the input, one row per example
    batch_size = len(next(iter(inputs.values())))
    raw_prompts = inputs["prompt"]
    rows = [{key: column[idx] for key, column in inputs.items()} for idx in range(batch_size)]
    rows = [maybe_apply_chat_template(row, self.processing_class) for row in rows]
    rows = [
        self.tokenize_row(row, self.model.config.is_encoder_decoder, self.processing_class) for row in rows
    ]
    collated = self.data_collator(rows)

    # Only the prompt_* entries are needed from here on
    collated = self._prepare_inputs(collated)
    context_length = collated["prompt_input_ids"].shape[1]
    prompts = {
        "input_ids": collated["prompt_input_ids"],
        "attention_mask": collated["prompt_attention_mask"],
        "raw": raw_prompts,
    }
    del inputs, rows, collated

    # Sample completions from both the model and the reference model
    model_output, ref_output = self._generate_completions(prompts, model)
    model_data, ref_data = self._process_completions(model_output, ref_output, prompts)

    # Decide, per prompt, whether the policy completion is the chosen one
    if self.reward_model is None:
        model_scores, ref_scores = None, None
        chosen_mask = self._compute_judge(model_data, ref_data, context_length)
    else:
        model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
        chosen_mask = model_scores >= ref_scores

    # Compute logprobs
    logprobs = self._compute_logprobs(model, model_data, ref_data, context_length)
    model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = logprobs

    # Compute loss
    loss, dpo_losses, xpo_losses = self._compute_losses(
        model_logprobs_model_data,
        model_logprobs_ref_data,
        ref_logprobs_ref_data,
        ref_logprobs_model_data,
        chosen_mask,
    )

    # Log everything (policy-side tensors are detached first)
    self._log_statistics(
        model_data,
        ref_data,
        model_logprobs_model_data.detach(),
        model_logprobs_ref_data.detach(),
        ref_logprobs_ref_data,
        ref_logprobs_model_data,
        chosen_mask,
        dpo_losses.detach(),
        xpo_losses.detach(),
        context_length,
        model_scores,
        ref_scores,
    )

    empty_cache_steps = self.args.torch_empty_cache_steps
    if empty_cache_steps is not None and self.state.global_step % empty_cache_steps == 0:
        empty_cache()

    # For LOMO optimizers you need to explicitly use the learning rate
    backward_kwargs = {}
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        backward_kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    if not self.use_apex:
        self.accelerator.backward(loss, **backward_kwargs)
    else:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()

    return loss.detach() / self.args.gradient_accumulation_steps
817
+
818
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `args.output_dir`. Nothing happens on
    processes other than the world-zero process.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that is a local directory is not reported as a base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # normalize `tags` to a mutable set
    if tags is None:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    # The citation key names the XPO paper's first author (Xie et al., 2024);
    # the previous key, `jung2024binary`, belonged to a different paper.
    citation = textwrap.dedent("""\
    @article{xie2024exploratory,
        title        = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
        author       = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
        year         = 2024,
        eprint       = {arXiv:2405.21046}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="XPO",
        trainer_citation=citation,
        paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
        paper_id="2405.21046",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
879
class UnslothXPOTrainer(_UnslothXPOTrainer):
    """

    Initialize XPOTrainer as a subclass of [`OnlineDPOTrainer`].

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation
            and loss. If no reference model is provided, the trainer will create a reference model with the same
            architecture as the model to be optimized.
        reward_model (`transformers.PreTrainedModel`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`XPOConfig`):
            The XPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.

    """
    def __init__(
        self,
        model = None,
        ref_model = None,
        reward_model = None,
        judge = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        peft_config = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothXPOConfig()

        # --- Mixed precision: reconcile args.fp16 / args.bf16 with the model dtype ---
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the dtype of the model
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # --- Evaluation defaults ---
        # Test the `eval_dataset` argument itself: the config object carries no
        # `eval_dataset` attribute, so looking it up on `args` never enabled this.
        if eval_dataset is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # 8 is treated as "left at its default": shrink it to the train batch size
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metric callbacks are given: set the env flag that requests logits
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit max_seq_length from the model when the config leaves it unset
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if getattr(args, 'max_seq_length', None) is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # --- Tokenizer / data collator reconciliation ---
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        collator_tokenizer = processing_class
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator type depending on whether the dataset has a 'labels' column
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(collator_tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(collator_tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` but wraps a tokenizer: collate with the inner tokenizer
            if not hasattr(collator_tokenizer, 'pad') and hasattr(collator_tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(collator_tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(collator_tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('xpo_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            judge = judge,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)

        # Remove the NEFTune forward hook if one was registered, and store the
        # noise alpha on the input embeddings instead
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        # Attach the accelerator's scaler to every level of nested `.model` wrappers
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-312.pyc ADDED
Binary file (34.2 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc ADDED
Binary file (86.3 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc ADDED
Binary file (72.1 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-312.pyc ADDED
Binary file (46.2 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2db1830ac6e205085f6ce375f4fb879adecec7b68103785df83a47a1a0bd7ae0
3
+ size 116397
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc ADDED
Binary file (35.6 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60910b7f8cf9df113b6079a6ad701b546ed5bd6b071c7adb30d4f5702e8539d0
3
+ size 136535
unsloth_compiled_cache/__pycache__/UnslothIterativeSFTTrainer.cpython-312.pyc ADDED
Binary file (41.5 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc ADDED
Binary file (89.8 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc ADDED
Binary file (44.2 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc ADDED
Binary file (70.2 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc ADDED
Binary file (63.3 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc ADDED
Binary file (63.3 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc ADDED
Binary file (35.6 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc ADDED
Binary file (54.2 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc ADDED
Binary file (38.1 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc ADDED
Binary file (54.6 kB). View file
 
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc ADDED
Binary file (46.6 kB). View file