from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging


class BreakEachEpoch(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that requests a training stop at the end of every epoch.

    At each ``on_epoch_end`` event it sets ``control.should_training_stop = True`` and logs an
    info message, so a ``Trainer`` run lasts at most one epoch per ``train()`` call.
    Going by the log message, the intent is to let the caller reload a new dataset shard and
    then resume training. NOTE(review): the reload/resume loop lives outside this class; confirm
    against the calling code.
    """

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """
        Flag the trainer to stop once the current epoch has finished.

        Args:
            args: Training arguments of the run (unused here).
            state: Current trainer state (unused here).
            control: Control object shared with the Trainer; mutated in place.
            **kwargs: Extra objects the Trainer passes to callbacks (unused here).

        Returns:
            The same ``control`` object, with ``should_training_stop`` set to ``True``.
        """
        # Stop unconditionally on every epoch, not only the last one, so the
        # caller gets a chance to swap in the next data shard.
        control.should_training_stop = True
        # The no-argument lookup returns the transformers library root logger, so the
        # library's verbosity settings decide whether this message is shown.
        logging.get_logger().info("Break each epoch for reload new shard dataset")
        return control