import numpy as np import torch from os.path import join import os class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" def __init__(self, patience=7, verbose=False, delta=0, save_path="."): """ Args: patience (int): How long to wait after last time validation loss improved. Default: 7 verbose (bool): If True, prints a message for each validation loss improvement. Default: False delta (float): Minimum change in the monitored quantity to qualify as an improvement. Default: 0 """ self.patience = patience self.verbose = verbose self.counter = 0 self.best_score = None self.early_stop = False self.val_loss_min = np.Inf self.delta = delta self.save_path = save_path def __call__(self, val_loss, model): score = -val_loss if self.best_score is None: self.best_score = score self.save_checkpoint(val_loss, model) elif score < self.best_score + self.delta: self.counter += 1 print(f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True else: self.best_score = score self.save_checkpoint(val_loss, model) self.counter = 0 def save_checkpoint(self, val_loss, model): '''Saves model when validation loss decrease.''' if self.verbose: print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') # save_path = join(self.save_path, "best_model") # if not os.path.exists(save_path): # os.mkdir(save_path) # model_to_save = model.module if hasattr(model, 'module') else model # model_to_save.save_pretrained(save_path) self.val_loss_min = val_loss