glenn-jocher commited on
Commit
93cc015
1 Parent(s): 8b18b66

Add EarlyStopping feature (#4576)

Browse files

* Add EarlyStopping feature

* Add comment

* Cleanup

* Cleanup2

* debug

* debug2

* debug3

* debug3

* debug4

* debug5

* debug6

* debug7

* debug8

* debug9

* debug10

* debug11

* debug12

* Cleanup

* Add TODO for known DDP issue

Files changed (2) hide show
  1. train.py +18 -1
  2. utils/torch_utils.py +17 -0
train.py CHANGED
@@ -40,7 +40,8 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
40
  from utils.downloads import attempt_download
41
  from utils.loss import ComputeLoss
42
  from utils.plots import plot_labels, plot_evolve
43
- from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
 
44
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
45
  from utils.metrics import fitness
46
  from utils.loggers import Loggers
@@ -255,6 +256,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
255
  results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
256
  scheduler.last_epoch = start_epoch - 1 # do not move
257
  scaler = amp.GradScaler(enabled=cuda)
 
258
  compute_loss = ComputeLoss(model) # init loss class
259
  LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
260
  f'Using {train_loader.num_workers} dataloader workers\n'
@@ -389,6 +391,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
389
  del ckpt
390
  callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  # end epoch ----------------------------------------------------------------------------------------------------
393
  # end training -----------------------------------------------------------------------------------------------------
394
  if RANK in [-1, 0]:
@@ -454,6 +470,7 @@ def parse_opt(known=False):
454
  parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
455
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
456
  parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
 
457
  opt = parser.parse_known_args()[0] if known else parser.parse_args()
458
  return opt
459
 
 
40
  from utils.downloads import attempt_download
41
  from utils.loss import ComputeLoss
42
  from utils.plots import plot_labels, plot_evolve
43
+ from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
44
+ torch_distributed_zero_first
45
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
46
  from utils.metrics import fitness
47
  from utils.loggers import Loggers
 
256
  results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
257
  scheduler.last_epoch = start_epoch - 1 # do not move
258
  scaler = amp.GradScaler(enabled=cuda)
259
+ stopper = EarlyStopping(patience=opt.patience)
260
  compute_loss = ComputeLoss(model) # init loss class
261
  LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
262
  f'Using {train_loader.num_workers} dataloader workers\n'
 
391
  del ckpt
392
  callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
393
 
394
+ # Stop Single-GPU
395
+ if stopper(epoch=epoch, fitness=fi):
396
+ break
397
+
398
+ # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
399
+ # stop = stopper(epoch=epoch, fitness=fi)
400
+ # if RANK == 0:
401
+ # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
402
+
403
+ # Stop DPP
404
+ # with torch_distributed_zero_first(RANK):
405
+ # if stop:
406
+ # break # must break all DDP ranks
407
+
408
  # end epoch ----------------------------------------------------------------------------------------------------
409
  # end training -----------------------------------------------------------------------------------------------------
410
  if RANK in [-1, 0]:
 
470
  parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
471
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
472
  parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
473
+ parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (epochs)')
474
  opt = parser.parse_known_args()[0] if known else parser.parse_args()
475
  return opt
476
 
utils/torch_utils.py CHANGED
@@ -293,6 +293,23 @@ def copy_attr(a, b, include=(), exclude=()):
293
  setattr(a, k, v)
294
 
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  class ModelEMA:
297
  """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
298
  Keep a moving average of everything in the model state_dict (parameters and buffers).
 
293
  setattr(a, k, v)
294
 
295
 
296
+ class EarlyStopping:
297
+ # YOLOv5 simple early stopper
298
+ def __init__(self, patience=30):
299
+ self.best_fitness = 0.0 # i.e. mAP
300
+ self.best_epoch = 0
301
+ self.patience = patience # epochs to wait after fitness stops improving to stop
302
+
303
+ def __call__(self, epoch, fitness):
304
+ if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
305
+ self.best_epoch = epoch
306
+ self.best_fitness = fitness
307
+ stop = (epoch - self.best_epoch) >= self.patience # stop training if patience exceeded
308
+ if stop:
309
+ LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
310
+ return stop
311
+
312
+
313
  class ModelEMA:
314
  """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
315
  Keep a moving average of everything in the model state_dict (parameters and buffers).