make sure to save on the last step (#1615)
Browse files
src/axolotl/utils/callbacks/__init__.py
CHANGED
@@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
778 |
class SaveModelOnTrainEndCallback(TrainerCallback):
|
779 |
"""Callback to save model on train end"""
|
780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
def on_train_end( # pylint: disable=unused-argument
|
782 |
self, args, state, control, **kwargs
|
783 |
):
|
|
|
778 |
class SaveModelOnTrainEndCallback(TrainerCallback):
|
779 |
"""Callback to save model on train end"""
|
780 |
|
781 |
+
def on_step_end( # pylint: disable=unused-argument
|
782 |
+
self,
|
783 |
+
args: TrainingArguments,
|
784 |
+
state: TrainerState,
|
785 |
+
control: TrainerControl,
|
786 |
+
**kwargs,
|
787 |
+
):
|
788 |
+
# Save
|
789 |
+
if state.global_step >= state.max_steps:
|
790 |
+
control.should_save = True
|
791 |
+
|
792 |
def on_train_end( # pylint: disable=unused-argument
|
793 |
self, args, state, control, **kwargs
|
794 |
):
|