| from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging | |
class BreakEachEpoch(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that requests a training stop at the end of every epoch.

    Judging by the message it logs, the intent is to let the surrounding script load a new dataset shard
    before training continues. This callback only raises the stop flag; it does not load any data itself.
    """

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Mark the run as "should stop" once an epoch has finished.

        Args:
            args: Training arguments supplied by the caller (not used here).
            state: Current trainer state (not used here).
            control: Control object whose ``should_training_stop`` flag is set to ``True``.
            **kwargs: Any additional keyword arguments passed to the callback (ignored).

        Returns:
            The same ``control`` object that was passed in, with the stop flag set.
        """
        control.should_training_stop = True

        # Record why the stop was requested; the logger comes from the transformers logging helper.
        epoch_logger = logging.get_logger()
        epoch_logger.info("Break each epoch for reload new shard dataset")

        return control