Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from swift.utils import get_logger
# Module-level logger from swift's logging utilities; used by EarlyStopCallback below.
logger = get_logger()
class EarlyStopCallback(TrainerCallback):
    """An early stop implementation driven by the trainer's tracked best metric.

    Each time a checkpoint is saved, ``state.best_metric`` is compared with the
    best value this callback saw at a previous save. An improvement resets the
    patience counter. A save without improvement increments it. Once
    ``total_interval`` consecutive saves pass without improvement,
    ``control.should_training_stop`` is set.

    Args:
        total_interval: Number of consecutive non-improving saves tolerated
            before training is stopped.
    """

    def __init__(self, total_interval=3):
        # Best metric value seen at a previous save; None until one has been recorded.
        self.best_metric = None
        # Number of consecutive saves since the metric last improved.
        self.interval = 0
        self.total_interval = total_interval

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update the patience counter on save and request a stop when patience runs out.

        Mutates ``control`` in place by setting ``should_training_stop`` when the
        metric has not improved for ``total_interval`` consecutive saves.
        """
        # The trainer leaves ``best_metric`` as None until it tracks a best metric;
        # with nothing to compare there is no basis for early stopping yet.
        if state.best_metric is None:
            return
        is_better = np.greater if args.greater_is_better else np.less
        if self.best_metric is None or is_better(state.best_metric, self.best_metric):
            self.best_metric = state.best_metric
            # Reset patience so only *consecutive* non-improving saves trigger a stop.
            self.interval = 0
        else:
            self.interval += 1
            if self.interval >= self.total_interval:
                logger.info(f'Training stop because of eval metric is stable at step {state.global_step}')
                control.should_training_stop = True
# Extra callbacks exported by this module; none are enabled by default.
extra_callbacks = list()
# Example: to enable the EarlyStopCallback defined above, uncomment the following line.
# extra_callbacks = [EarlyStopCallback()]