support adamw and grad norm hyperparams
Browse files
src/axolotl/utils/trainer.py
CHANGED
@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
115 |
# TODO search Path("./") for one
|
116 |
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
training_args = transformers.TrainingArguments(
|
119 |
per_device_train_batch_size=cfg.micro_batch_size,
|
120 |
per_device_eval_batch_size=cfg.eval_batch_size
|
|
|
115 |
# TODO search Path("./") for one
|
116 |
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
117 |
|
118 |
+
if cfg.adam_beta1:
|
119 |
+
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
120 |
+
if cfg.adam_beta2:
|
121 |
+
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
122 |
+
if cfg.adam_epsilon:
|
123 |
+
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
124 |
+
if cfg.max_grad_norm:
|
125 |
+
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
126 |
+
|
127 |
training_args = transformers.TrainingArguments(
|
128 |
per_device_train_batch_size=cfg.micro_batch_size,
|
129 |
per_device_eval_batch_size=cfg.eval_batch_size
|