# vaw2tmp/callbacks.py
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging

logger = logging.get_logger(__name__)


class BreakEachEpoch(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
and checkpoints.
"""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Signal the Trainer to stop, so the caller can swap in the next dataset shard and call train() again.
        control.should_training_stop = True
        logger.info("Stopping at end of epoch to reload the next dataset shard")
        return control
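
# A minimal usage sketch, not part of this file: BreakEachEpoch is registered with a Trainer so that
# each trainer.train() call runs a single epoch on one dataset shard, while an outer loop swaps in the
# next shard. load_shard() is a hypothetical helper, and resuming assumes a checkpoint is saved every
# epoch (e.g. save_strategy="epoch" in the TrainingArguments).
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=load_shard(0),
#         callbacks=[BreakEachEpoch()],
#     )
#     for epoch in range(int(training_args.num_train_epochs)):
#         trainer.train_dataset = load_shard(epoch)  # hypothetical shard loader
#         trainer.train(resume_from_checkpoint=(epoch > 0))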