# Contribution from @fredguth, https://github.com/fredguth/fastai_playground.
from fastai.torch_core import *
from fastai.callback import *
from fastai.basic_train import *
__all__ = ['TerminateOnNaNCallback', 'EarlyStoppingCallback', 'SaveModelCallback', 'TrackerCallback',
           'ReduceLROnPlateauCallback', 'TrackEpochCallback']
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."
    def __init__(self):
        self.stop = False
    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->None:
        "Test if `last_loss` is NaN and interrupt training."
        if self.stop: return True #to skip validation after stopping during training
        if torch.isnan(last_loss):
            print(f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
            return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}
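# Usage sketch (not part of the module; `learn` is assumed to be an existing `Learner`):
#   learn.fit_one_cycle(4, 1e-2, callbacks=[TerminateOnNaNCallback()])
# Returning the stop flags from `on_batch_end` halts training on the first NaN loss
# and skips the remaining validation pass.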
class TrackerCallback(LearnerCallback):
    "A `LearnerCallback` that keeps track of the best value in `monitor`."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto'):
        super().__init__(learn)
        self.monitor,self.mode = monitor,mode
        if self.mode not in ['auto', 'min', 'max']:
            warn(f'{self.__class__} mode {self.mode} is invalid, falling back to "auto" mode.')
            self.mode = 'auto'
        mode_dict = {'min': np.less, 'max':np.greater}
        mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater
        self.operator = mode_dict[self.mode]
    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize the best value."
        self.best = float('inf') if self.operator == np.less else -float('inf')
    def get_monitor_value(self):
        "Pick the monitored value."
        if self.monitor=='train_loss' and len(self.learn.recorder.losses) == 0: return None
        elif len(self.learn.recorder.val_losses) == 0: return None
        values = {'train_loss':self.learn.recorder.losses[-1].cpu().numpy(),
                  'valid_loss':self.learn.recorder.val_losses[-1]}
        if values['valid_loss'] is None: return
        if self.learn.recorder.metrics:
            for m, n in zip(self.learn.recorder.metrics[-1],self.learn.recorder.names[3:-1]):
                values[n] = m
        if values.get(self.monitor) is None:
            warn(f'{self.__class__} conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, self.learn.recorder.names[1:-1]))}')
        return values.get(self.monitor)
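# `TrackerCallback` is only a base class: subclasses call `get_monitor_value()` each epoch
# and compare it against `self.best` with `self.operator`. A minimal custom tracker might
# look like this (illustrative sketch; the class name and message are not part of fastai):
#   class PrintBestCallback(TrackerCallback):
#       def on_epoch_end(self, epoch, **kwargs):
#           current = self.get_monitor_value()
#           if current is not None and self.operator(current, self.best):
#               self.best = current
#               print(f'New best {self.monitor}: {current} at epoch {epoch}')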
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when monitored quantity stops improving."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', min_delta:int=0, patience:int=0):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.min_delta,self.patience = min_delta,patience
        if self.operator == np.less: self.min_delta *= -1
    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize inner arguments."
        self.wait = 0
        super().on_train_begin(**kwargs)
    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Compare the value monitored to its best score and maybe stop training."
        current = self.get_monitor_value()
        if current is None: return
        if self.operator(current - self.min_delta, self.best):
            self.best,self.wait = current,0
        else:
            self.wait += 1
            if self.wait > self.patience:
                print(f'Epoch {epoch}: early stopping')
                return {"stop_training":True}
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model when monitored quantity is best."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.every,self.name = every,name
        if self.every not in ['improvement', 'epoch']:
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'
    def jump_to_epoch(self, epoch:int)->None:
        "Try to load the model saved at the end of epoch `epoch-1`."
        try:
            self.learn.load(f'{self.name}_{epoch-1}', purge=False)
            print(f"Loaded {self.name}_{epoch-1}")
        except: print(f'Model {self.name}_{epoch-1} not found.')
    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
        "Compare the value monitored to its best score and maybe save the model."
        if self.every=="epoch": self.learn.save(f'{self.name}_{epoch}')
        else: #every="improvement"
            current = self.get_monitor_value()
            if current is not None and self.operator(current, self.best):
                print(f'Better model found at epoch {epoch} with {self.monitor} value: {current}.')
                self.best = current
                self.learn.save(f'{self.name}')
    def on_train_end(self, **kwargs):
        "Load the best model."
        if self.every=="improvement" and (self.learn.path/f'{self.learn.model_dir}/{self.name}.pth').is_file():
            self.learn.load(f'{self.name}', purge=False)
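# Usage sketch (the metric and file name are illustrative):
#   learn.fit_one_cycle(10, callbacks=[SaveModelCallback(learn, monitor='accuracy',
#                                                        every='improvement', name='best_acc')])
# With every='improvement' the best checkpoint is saved under the learner's model_dir
# (models/best_acc.pth by default) and reloaded automatically by `on_train_end`.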
class ReduceLROnPlateauCallback(TrackerCallback):
    "A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', patience:int=0, factor:float=0.2,
                 min_delta:int=0):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.patience,self.factor,self.min_delta = patience,factor,min_delta
        if self.operator == np.less: self.min_delta *= -1
    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize inner arguments."
        self.wait, self.opt = 0, self.learn.opt
        super().on_train_begin(**kwargs)
    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Compare the value monitored to its best and maybe reduce lr."
        current = self.get_monitor_value()
        if current is None: return
        if self.operator(current - self.min_delta, self.best): self.best,self.wait = current,0
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.opt.lr *= self.factor
                self.wait = 0
                print(f'Epoch {epoch}: reducing lr to {self.opt.lr}')
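# Usage sketch:
#   learn.fit(20, 1e-2, callbacks=[ReduceLROnPlateauCallback(learn, monitor='valid_loss',
#                                                            patience=2, factor=0.5)])
# After 2 epochs without improvement, the current learning rate is multiplied by 0.5.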
class TrackEpochCallback(LearnerCallback):
    _order = -20 #Need to run before fit_one_cycle
    def __init__(self, learn:Learner, name:str='epoch', epoch_offset:int=None):
        "Store completed epoch number in `learn.model_dir/name`."
        super().__init__(learn)
        learn._test_writeable_path()
        self.path = learn.path/learn.model_dir/name
        if epoch_offset is None:
            if os.path.isfile(self.path):
                with self.path.open('r') as f:
                    try: self.start_epoch = int(f.read())+1
                    except: self.start_epoch = 0
            else: self.start_epoch = 0
        else: self.start_epoch = epoch_offset #start counting from an explicit offset
    def on_train_begin(self, **kwargs:Any):
        "Make training start at `self.start_epoch`."
        return {'epoch': self.start_epoch}
    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Write the last completed epoch to `self.path`."
        with self.path.open('w') as f: f.write(f'{epoch}')
    def restart(self): os.remove(self.path)
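# Usage sketch (resuming a run that previously wrote `models/epoch`):
#   cb = TrackEpochCallback(learn)           #picks up the last completed epoch from disk
#   learn.fit_one_cycle(10, callbacks=[cb])  #epoch numbering continues where it left off
#   cb.restart()                             #delete the stored file to start counting from 0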