Filippo Broggini winglian commited on
Commit
18f8119
1 Parent(s): afb5dd9

FEAT: add tagging support to axolotl for DPOTrainer (#1209)

Browse files

* Add AxolotlDPOTrainer

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (1) hide show
  1. src/axolotl/core/trainer_builder.py +36 -19
src/axolotl/core/trainer_builder.py CHANGED
@@ -59,6 +59,22 @@ except ImportError:
59
  LOG = logging.getLogger("axolotl.core.trainer_builder")
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  @dataclass
63
  class AxolotlTrainingArguments(TrainingArguments):
64
  """
@@ -349,30 +365,13 @@ class AxolotlTrainer(Trainer):
349
  # return (loss, outputs) if return_outputs else loss
350
  return super().compute_loss(model, inputs, return_outputs=return_outputs)
351
 
352
- def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
353
- if isinstance(tag_names, str):
354
- tag_names = [tag_names]
355
-
356
- if kwargs is not None:
357
- if "tags" not in kwargs:
358
- kwargs["tags"] = tag_names
359
- elif "tags" in kwargs and isinstance(kwargs["tags"], list):
360
- kwargs["tags"].extend(tag_names)
361
- elif "tags" in kwargs and isinstance(kwargs["tags"], str):
362
- tag_names.append(kwargs["tags"])
363
- kwargs["tags"] = tag_names
364
-
365
- return kwargs
366
-
367
  @wraps(Trainer.push_to_hub)
368
  def push_to_hub(self, *args, **kwargs) -> str:
369
  """
370
  Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
371
  model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
372
  """
373
- kwargs = self._sanitize_kwargs_for_tagging(
374
- tag_names=self.tag_names, kwargs=kwargs
375
- )
376
 
377
  return super().push_to_hub(*args, **kwargs)
378
 
@@ -471,6 +470,24 @@ class ReLoRATrainer(AxolotlTrainer):
471
  return self.lr_scheduler
472
 
473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  class TrainerBuilderBase(abc.ABC):
475
  """
476
  Base class for trainer builder
@@ -1076,7 +1093,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
1076
  dpo_trainer_kwargs[
1077
  "precompute_ref_log_probs"
1078
  ] = self.cfg.precompute_ref_log_probs
1079
- dpo_trainer = DPOTrainer(
1080
  self.model,
1081
  self.model_ref,
1082
  args=training_args,
 
59
  LOG = logging.getLogger("axolotl.core.trainer_builder")
60
 
61
 
62
+ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
63
+ if isinstance(tag_names, str):
64
+ tag_names = [tag_names]
65
+
66
+ if kwargs is not None:
67
+ if "tags" not in kwargs:
68
+ kwargs["tags"] = tag_names
69
+ elif "tags" in kwargs and isinstance(kwargs["tags"], list):
70
+ kwargs["tags"].extend(tag_names)
71
+ elif "tags" in kwargs and isinstance(kwargs["tags"], str):
72
+ tag_names.append(kwargs["tags"])
73
+ kwargs["tags"] = tag_names
74
+
75
+ return kwargs
76
+
77
+
78
  @dataclass
79
  class AxolotlTrainingArguments(TrainingArguments):
80
  """
 
365
  # return (loss, outputs) if return_outputs else loss
366
  return super().compute_loss(model, inputs, return_outputs=return_outputs)
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  @wraps(Trainer.push_to_hub)
369
  def push_to_hub(self, *args, **kwargs) -> str:
370
  """
371
  Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
372
  model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
373
  """
374
+ kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
 
 
375
 
376
  return super().push_to_hub(*args, **kwargs)
377
 
 
470
  return self.lr_scheduler
471
 
472
 
473
+ class AxolotlDPOTrainer(DPOTrainer):
474
+ """
475
+ Extend the base DPOTrainer for axolotl helpers
476
+ """
477
+
478
+ tag_names = ["axolotl", "dpo"]
479
+
480
+ @wraps(DPOTrainer.push_to_hub)
481
+ def push_to_hub(self, *args, **kwargs) -> str:
482
+ """
483
+ Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
484
+ model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
485
+ """
486
+ kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
487
+
488
+ return super().push_to_hub(*args, **kwargs)
489
+
490
+
491
  class TrainerBuilderBase(abc.ABC):
492
  """
493
  Base class for trainer builder
 
1093
  dpo_trainer_kwargs[
1094
  "precompute_ref_log_probs"
1095
  ] = self.cfg.precompute_ref_log_probs
1096
+ dpo_trainer = AxolotlDPOTrainer(
1097
  self.model,
1098
  self.model_ref,
1099
  args=training_args,