FEAT: add tagging support to axolotl for DPOTrainer (#1209)
* Add AxolotlDPOTrainer
* chore: lint
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/core/trainer_builder.py
CHANGED
@@ -59,6 +59,22 @@ except ImportError:
|
|
59 |
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
60 |
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
@dataclass
|
63 |
class AxolotlTrainingArguments(TrainingArguments):
|
64 |
"""
|
@@ -349,30 +365,13 @@ class AxolotlTrainer(Trainer):
|
|
349 |
# return (loss, outputs) if return_outputs else loss
|
350 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
351 |
|
352 |
-
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
353 |
-
if isinstance(tag_names, str):
|
354 |
-
tag_names = [tag_names]
|
355 |
-
|
356 |
-
if kwargs is not None:
|
357 |
-
if "tags" not in kwargs:
|
358 |
-
kwargs["tags"] = tag_names
|
359 |
-
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
360 |
-
kwargs["tags"].extend(tag_names)
|
361 |
-
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
362 |
-
tag_names.append(kwargs["tags"])
|
363 |
-
kwargs["tags"] = tag_names
|
364 |
-
|
365 |
-
return kwargs
|
366 |
-
|
367 |
@wraps(Trainer.push_to_hub)
|
368 |
def push_to_hub(self, *args, **kwargs) -> str:
|
369 |
"""
|
370 |
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
371 |
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
372 |
"""
|
373 |
-
kwargs = self.
|
374 |
-
tag_names=self.tag_names, kwargs=kwargs
|
375 |
-
)
|
376 |
|
377 |
return super().push_to_hub(*args, **kwargs)
|
378 |
|
@@ -471,6 +470,24 @@ class ReLoRATrainer(AxolotlTrainer):
|
|
471 |
return self.lr_scheduler
|
472 |
|
473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
class TrainerBuilderBase(abc.ABC):
|
475 |
"""
|
476 |
Base class for trainer builder
|
@@ -1076,7 +1093,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|
1076 |
dpo_trainer_kwargs[
|
1077 |
"precompute_ref_log_probs"
|
1078 |
] = self.cfg.precompute_ref_log_probs
|
1079 |
-
dpo_trainer =
|
1080 |
self.model,
|
1081 |
self.model_ref,
|
1082 |
args=training_args,
|
|
|
59 |
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
60 |
|
61 |
|
62 |
+
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
63 |
+
if isinstance(tag_names, str):
|
64 |
+
tag_names = [tag_names]
|
65 |
+
|
66 |
+
if kwargs is not None:
|
67 |
+
if "tags" not in kwargs:
|
68 |
+
kwargs["tags"] = tag_names
|
69 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
70 |
+
kwargs["tags"].extend(tag_names)
|
71 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
72 |
+
tag_names.append(kwargs["tags"])
|
73 |
+
kwargs["tags"] = tag_names
|
74 |
+
|
75 |
+
return kwargs
|
76 |
+
|
77 |
+
|
78 |
@dataclass
|
79 |
class AxolotlTrainingArguments(TrainingArguments):
|
80 |
"""
|
|
|
365 |
# return (loss, outputs) if return_outputs else loss
|
366 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
    """
    Push the model to the Hub, making sure this trainer's tags are included.

    The keyword arguments are run through `_sanitize_kwargs_for_tagging`
    together with `self.tag_names`, then everything is forwarded to the
    parent implementation. Please refer to `~transformers.Trainer.push_to_hub`
    for more details.
    """
    tagged_kwargs = _sanitize_kwargs_for_tagging(
        tag_names=self.tag_names,
        kwargs=kwargs,
    )
    return super().push_to_hub(*args, **tagged_kwargs)
|
377 |
|
|
|
470 |
return self.lr_scheduler
|
471 |
|
472 |
|
473 |
+
class AxolotlDPOTrainer(DPOTrainer):
    """
    DPOTrainer subclass that carries the axolotl helpers
    """

    # Tags merged into the `tags` keyword argument on every `push_to_hub` call.
    tag_names = ["axolotl", "dpo"]

    @wraps(DPOTrainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Push the model to the Hub, making sure this trainer's tags are included.

        The keyword arguments are run through `_sanitize_kwargs_for_tagging`
        together with `self.tag_names`, then everything is forwarded to the
        parent implementation. Please refer to
        `~transformers.Trainer.push_to_hub` for more details.
        """
        tagged_kwargs = _sanitize_kwargs_for_tagging(
            tag_names=self.tag_names,
            kwargs=kwargs,
        )
        return super().push_to_hub(*args, **tagged_kwargs)
|
489 |
+
|
490 |
+
|
491 |
class TrainerBuilderBase(abc.ABC):
|
492 |
"""
|
493 |
Base class for trainer builder
|
|
|
1093 |
dpo_trainer_kwargs[
|
1094 |
"precompute_ref_log_probs"
|
1095 |
] = self.cfg.precompute_ref_log_probs
|
1096 |
+
dpo_trainer = AxolotlDPOTrainer(
|
1097 |
self.model,
|
1098 |
self.model_ref,
|
1099 |
args=training_args,
|