fix: add lr scheduler kwargs to Trainer (#972)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
@@ -692,6 +692,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
692 |
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
693 |
else "cosine"
|
694 |
)
|
|
|
|
|
|
|
695 |
training_arguments_kwargs["weight_decay"] = (
|
696 |
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
697 |
)
|
|
|
692 |
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
693 |
else "cosine"
|
694 |
)
|
695 |
+
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
696 |
+
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
697 |
+
)
|
698 |
training_arguments_kwargs["weight_decay"] = (
|
699 |
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
700 |
)
|