| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import numpy as np | |
| from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments | |
| from swift.utils import get_logger | |
| logger = get_logger() | |
class EarlyStopCallback(TrainerCallback):
    """Stop training once the best eval metric has stopped improving.

    Each time a checkpoint is saved, ``state.best_metric`` is compared with
    the best value this callback has recorded so far. A save that brings no
    improvement increments a counter; a save that does improve resets it.
    When the counter reaches ``total_interval``, ``control.should_training_stop``
    is set to ``True``.

    Args:
        total_interval: Number of consecutive saves without improvement
            after which training is asked to stop.
    """

    def __init__(self, total_interval=3):
        # Best metric value observed so far; None until the first save.
        self.best_metric = None
        # Consecutive saves seen without an improvement of the best metric.
        self.interval = 0
        self.total_interval = total_interval

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update the no-improvement counter and request a stop if it is exhausted.

        Args:
            args: Training arguments; ``greater_is_better`` selects the
                comparison direction (falsy values mean "lower is better").
            state: Trainer state; ``best_metric`` and ``global_step`` are read.
            control: Trainer control; ``should_training_stop`` is set to
                ``True`` when the counter reaches ``total_interval``.
            **kwargs: Ignored.
        """
        is_better = np.greater if args.greater_is_better else np.less
        if self.best_metric is None or is_better(state.best_metric, self.best_metric):
            self.best_metric = state.best_metric
            # An improvement ends the current plateau, so start counting again.
            # Without this reset the counter would accumulate over the whole
            # run and could stop training right after a new best checkpoint.
            self.interval = 0
        else:
            self.interval += 1

        if self.interval >= self.total_interval:
            logger.info(f'Training stop because of eval metric is stable at step {state.global_step}')
            control.should_training_stop = True
# Additional trainer callbacks exposed by this module; empty by default.
extra_callbacks = []
# EarlyStopCallback above is a minimal early-stopping example.
# To enable it, uncomment the following line:
# extra_callbacks = [EarlyStopCallback()]