use save_strategy from config if available (#434)
Browse files
* use save_strategy from config if available
* update docs for save_strategy
- README.md +1 -0
- src/axolotl/utils/trainer.py +7 -1
README.md
CHANGED
@@ -472,6 +472,7 @@ warmup_steps: 100
|
|
472 |
learning_rate: 0.00003
|
473 |
lr_quadratic_warmup:
|
474 |
logging_steps:
|
|
|
475 |
save_steps: # leave empty to save at each epoch
|
476 |
eval_steps:
|
477 |
save_total_limit: # checkpoints saved at a time
|
|
|
472 |
learning_rate: 0.00003
|
473 |
lr_quadratic_warmup:
|
474 |
logging_steps:
|
475 |
+
save_strategy: # set to "no" (quoted; a bare `no` is a YAML 1.1 boolean false) to skip checkpoint saves
|
476 |
save_steps: # leave empty to save at each epoch
|
477 |
eval_steps:
|
478 |
save_total_limit: # checkpoints saved at a time
|
src/axolotl/utils/trainer.py
CHANGED
@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
457 |
# we have an eval set, but no steps defined, use epoch
|
458 |
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
459 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
461 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
462 |
max_seq_length=cfg.sequence_len,
|
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
468 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
469 |
num_train_epochs=cfg.num_epochs,
|
470 |
learning_rate=cfg.learning_rate,
|
471 |
-
save_strategy="steps" if cfg.save_steps else "epoch",
|
472 |
save_steps=cfg.save_steps,
|
473 |
output_dir=cfg.output_dir,
|
474 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
|
|
457 |
# we have an eval set, but no steps defined, use epoch
|
458 |
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
459 |
|
460 |
+
if cfg.save_strategy:
|
461 |
+
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
|
462 |
+
else:
|
463 |
+
training_arguments_kwargs["save_strategy"] = (
|
464 |
+
"steps" if cfg.save_steps else "epoch"
|
465 |
+
)
|
466 |
+
|
467 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
468 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
469 |
max_seq_length=cfg.sequence_len,
|
|
|
475 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
476 |
num_train_epochs=cfg.num_epochs,
|
477 |
learning_rate=cfg.learning_rate,
|
|
|
478 |
save_steps=cfg.save_steps,
|
479 |
output_dir=cfg.output_dir,
|
480 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|