# sentimentA / early_stopping.py
# Uploaded by EATHARD ("Upload 6 files", commit d23d4f9, verified).
import torch
class EarlyStopping:
    """Early stop the training if the current metric is worse than the best one for
    longer than ``wait_epochs`` epochs, or if the metric stops changing.

    Parameters
    ----------
    wait_epochs: int, optional (default=2)
        Number of consecutive clearly-worse epochs (worse than the best score by
        more than the tolerance) allowed before training is stopped. A plateau
        (metric within ``delta / 3`` of the best score, relative to the metric's
        magnitude) stops training after ``wait_epochs + 1`` consecutive epochs.
    checkpoint_path: str or path-like, optional (default='checkpoint.pt')
        Destination passed to ``torch.save`` for the best model's ``state_dict``.
    """

    _METRIC_TYPES = ('better_decrease', 'better_increase')

    def __init__(self, wait_epochs=2, checkpoint_path='checkpoint.pt'):
        self.wait_epochs = wait_epochs
        self.checkpoint_path = checkpoint_path
        # Consecutive epochs that were worse than the best by more than the tolerance.
        self.num_bad_scores = 0
        # Consecutive epochs whose metric stayed within delta/3 of the best score.
        self.num_const_scores = 0
        # Best metric seen so far; None until the first call to ``stop``.
        self.best_score = None
        # Metric value of the most recently saved checkpoint.
        self.best_metric = 0

    def stop(self, metric, model, metric_type='better_decrease', delta=0.03):
        """Stop the training if metric criteria aren't met.

        Parameters
        ----------
        metric: float
            Metric used to evaluate the validation performance.
        model: pytorch model
            Pytorch model instance; its ``state_dict`` is saved whenever the
            metric ties or improves on the best score.
        metric_type: str, optional (default='better_decrease')
            Specify the metric type, available options: better_decrease, better_increase.
        delta: float, optional (default=0.03)
            The minimum change of a metric that is considered in the stopping
            decision, expressed as a fraction of the metric's magnitude.

        Returns
        -------
        Boolean
            True if training should be stopped, otherwise False.

        Raises
        ------
        ValueError
            If ``metric_type`` is not one of the available options.
        """
        if metric_type not in self._METRIC_TYPES:
            raise ValueError(
                f"metric_type must be one of {self._METRIC_TYPES}, got {metric_type!r}"
            )
        self.delta = delta
        # The tolerance is relative to the metric's magnitude. abs() keeps it
        # non-negative so the thresholds below do not flip for negative metrics.
        magnitude = abs(metric)
        tolerance = self.delta * magnitude

        if self.best_score is None:
            # First call: the initial metric is the best so far.
            self.best_score = metric
            self.save_model_state(metric, model)
            return False

        # Plateau detection: the metric stays very close to the best score.
        if abs(metric - self.best_score) < self.delta / 3 * magnitude:
            self.num_const_scores += 1
            if self.num_const_scores >= self.wait_epochs + 1:
                print('\nTraining stoped by EarlyStopping')
                return True
        else:
            self.num_const_scores = 0

        if metric_type == 'better_decrease':
            clearly_worse = metric > self.best_score + tolerance
            slightly_worse = metric > self.best_score
        else:  # 'better_increase'
            clearly_worse = metric < self.best_score - tolerance
            slightly_worse = metric < self.best_score

        if clearly_worse:
            self.num_bad_scores += 1
        elif slightly_worse:
            # Worse, but within tolerance: breaks the streak of bad epochs.
            self.num_bad_scores = 0
        else:
            # Ties and improvements become the new best and are checkpointed.
            self.best_score = metric
            self.save_model_state(metric, model)
            self.num_bad_scores = 0

        if self.num_bad_scores >= self.wait_epochs:
            print('\nTraining stoped by EarlyStopping')
            return True
        return False

    def save_model_state(self, metric, model):
        """Save ``model.state_dict()`` to ``self.checkpoint_path`` and record ``metric``."""
        torch.save(model.state_dict(), self.checkpoint_path)
        self.best_metric = metric