fix missing fp16 kwarg
Browse files
src/axolotl/utils/trainer.py
CHANGED
@@ -56,6 +56,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):

Before:

    56      training_arguments_kwargs["bf16_full_eval"] = True
    57  else:
    58      training_arguments_kwargs["bf16"] = cfg.bf16
    59  training_arguments_kwargs["tf32"] = cfg.tf32
    60  training_arguments_kwargs["warmup_steps"] = warmup_steps
    61  training_arguments_kwargs["logging_steps"] = logging_steps
After:

    56      training_arguments_kwargs["bf16_full_eval"] = True
    57  else:
    58      training_arguments_kwargs["bf16"] = cfg.bf16
    59 +    training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
    60  training_arguments_kwargs["tf32"] = cfg.tf32
    61  training_arguments_kwargs["warmup_steps"] = warmup_steps
    62  training_arguments_kwargs["logging_steps"] = logging_steps