glenn-jocher committed on
Commit
5487451
1 Parent(s): f64fab5

EarlyStopper updates (#4679)

Browse files
Files changed (2) hide show
  1. train.py +3 -3
  2. utils/torch_utils.py +5 -2
train.py CHANGED
@@ -344,7 +344,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
344
  # mAP
345
  callbacks.on_train_epoch_end(epoch=epoch)
346
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
347
- final_epoch = epoch + 1 == epochs
348
  if not noval or final_epoch: # Calculate mAP
349
  results, maps, _ = val.run(data_dict,
350
  batch_size=batch_size // WORLD_SIZE * 2,
@@ -384,7 +384,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
384
  callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
385
 
386
  # Stop Single-GPU
387
- if stopper(epoch=epoch, fitness=fi):
388
  break
389
 
390
  # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576
@@ -462,7 +462,7 @@ def parse_opt(known=False):
462
  parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
463
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
464
  parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
465
- parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (epochs)')
466
  opt = parser.parse_known_args()[0] if known else parser.parse_args()
467
  return opt
468
 
 
344
  # mAP
345
  callbacks.on_train_epoch_end(epoch=epoch)
346
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
347
+ final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
348
  if not noval or final_epoch: # Calculate mAP
349
  results, maps, _ = val.run(data_dict,
350
  batch_size=batch_size // WORLD_SIZE * 2,
 
384
  callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
385
 
386
  # Stop Single-GPU
387
+ if RANK == -1 and stopper(epoch=epoch, fitness=fi):
388
  break
389
 
390
  # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576
 
462
  parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
463
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
464
  parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
465
+ parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
466
  opt = parser.parse_known_args()[0] if known else parser.parse_args()
467
  return opt
468
 
utils/torch_utils.py CHANGED
@@ -298,13 +298,16 @@ class EarlyStopping:
298
  def __init__(self, patience=30):
299
  self.best_fitness = 0.0 # i.e. mAP
300
  self.best_epoch = 0
301
- self.patience = patience # epochs to wait after fitness stops improving to stop
 
302
 
303
  def __call__(self, epoch, fitness):
304
  if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
305
  self.best_epoch = epoch
306
  self.best_fitness = fitness
307
- stop = (epoch - self.best_epoch) >= self.patience # stop training if patience exceeded
 
 
308
  if stop:
309
  LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
310
  return stop
 
298
  def __init__(self, patience=30):
299
  self.best_fitness = 0.0 # i.e. mAP
300
  self.best_epoch = 0
301
+ self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
302
+ self.possible_stop = False # possible stop may occur next epoch
303
 
304
  def __call__(self, epoch, fitness):
305
  if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
306
  self.best_epoch = epoch
307
  self.best_fitness = fitness
308
+ delta = epoch - self.best_epoch # epochs without improvement
309
+ self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
310
+ stop = delta >= self.patience # stop training if patience exceeded
311
  if stop:
312
  LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
313
  return stop