glenn-jocher committed
Commit ebafd1e · unverified · 1 Parent(s): 26c3b11

single command --resume (#756)

* single command --resume
* else check files, remove TODO
* argparse.Namespace()
* tensorboard lr
* bug fix in get_latest_run()

Files changed (2)
  1. train.py +30 -25
  2. utils/general.py +1 -1
train.py CHANGED
@@ -42,7 +42,6 @@ def train(hyp, opt, device, tb_writer=None):
     epochs, batch_size, total_batch_size, weights, rank = \
         opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
 
-    # TODO: Use DDP logging. Only the first process is allowed to log.
     # Save run settings
     with open(log_dir / 'hyp.yaml', 'w') as f:
         yaml.dump(hyp, f, sort_keys=False)
@@ -130,6 +129,8 @@ def train(hyp, opt, device, tb_writer=None):
 
         # Epochs
         start_epoch = ckpt['epoch'] + 1
+        if opt.resume:
+            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
             logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                         (weights, ckpt['epoch'], epochs))
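
The new assert backs the single-command resume: a checkpoint whose training already finished should refuse to resume rather than quietly start again. A minimal sketch of the idea, under the assumption that a finished run's checkpoint ends up with epoch == -1 (e.g. after its optimizer is stripped at the end of training), so start_epoch computes to 0:

# Illustration only; the epoch == -1 convention for finished checkpoints is an assumption.
ckpt = {'epoch': -1}                     # stand-in for torch.load(weights)
start_epoch = ckpt['epoch'] + 1          # 0 -> nothing left to resume
try:
    assert start_epoch > 0, 'training is finished, nothing to resume.'
except AssertionError as e:
    print(e)                             # clear error instead of silently retraining
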
@@ -158,19 +159,19 @@ def train(hyp, opt, device, tb_writer=None):
         model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
 
     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
-                                            cache=opt.cache_images, rect=opt.rect, rank=rank,
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                             world_size=opt.world_size, workers=opt.workers)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
+    ema.updates = start_epoch * nb // accumulate  # set EMA updates
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
 
     # Testloader
     if rank in [-1, 0]:
-        # local_rank is set to -1. Because only the first process is expected to do evaluation.
-        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
-                                       cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size,
-                                       workers=opt.workers)[0]
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
+                                       hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1,
+                                       world_size=opt.world_size, workers=opt.workers)[0]  # only runs on process 0
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
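
Seeding ema.updates with start_epoch * nb // accumulate keeps the EMA where the interrupted run left off: the ModelEMA decay ramps up with the number of updates applied, so restarting the counter at zero would briefly let the freshly loaded weights dominate the average. A rough sketch of that ramp; the exact form and the 2000-update time constant are assumptions about ModelEMA, not part of this diff:

import math

def effective_decay(updates, base_decay=0.9999, tau=2000):
    # assumed warm-up form: decay grows from ~0 toward base_decay as updates accumulate
    return base_decay * (1 - math.exp(-updates / tau))

print(effective_decay(0))      # ~0.0  -> EMA tracks the raw model closely
print(effective_decay(6000))   # ~0.95 -> long-run averaging, e.g. after resuming mid-training
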
@@ -283,7 +284,7 @@ def train(hyp, opt, device, tb_writer=None):
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
-                if ema is not None:
+                if ema:
                     ema.update(model)
 
             # Print
@@ -305,12 +306,13 @@ def train(hyp, opt, device, tb_writer=None):
         # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
+        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
         scheduler.step()
 
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema is not None:
+            if ema:
                 ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
@@ -330,10 +332,11 @@ def train(hyp, opt, device, tb_writer=None):
 
             # Tensorboard
             if tb_writer:
-                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
+                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                         'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
-                for x, tag in zip(list(mloss[:-1]) + list(results), tags):
+                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                        'x/lr0', 'x/lr1', 'x/lr2']  # params
+                for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                     tb_writer.add_scalar(tag, x, epoch)
 
             # Update best mAP
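
The three new x/lr tags line up with the optimizer's parameter groups (train.py builds three of them, typically biases, BatchNorm weights, and weights with decay), and lr is read just before scheduler.step() so the logged value is the one actually used during the epoch. A self-contained sketch of the logging pattern with a toy model and scheduler, not the training code itself:

import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD([{'params': [model.bias]},
                             {'params': [model.weight], 'weight_decay': 5e-4}], lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda e: 0.95 ** e)
tb_writer = SummaryWriter()

for epoch in range(3):
    lr = [g['lr'] for g in optimizer.param_groups]  # read before scheduler.step(), as in the diff
    optimizer.step()
    scheduler.step()
    for i, x in enumerate(lr):
        tb_writer.add_scalar(f'x/lr{i}', x, epoch)  # -> x/lr0, x/lr1 curves in TensorBoard
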
@@ -389,8 +392,7 @@ if __name__ == '__main__':
     parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
-    parser.add_argument('--resume', nargs='?', const='get_last', default=False,
-                        help='resume from given path/last.pt, or most recent run if blank')
+    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--notest', action='store_true', help='only test final epoch')
     parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
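
With nargs='?', const=True, default=False the flag now covers all three invocations in one place: omitted, bare --resume (pick the most recent run), or --resume with an explicit checkpoint path. A quick stand-alone check of that argparse behaviour; the example path is hypothetical:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')

print(parser.parse_args([]).resume)            # False -> normal training
print(parser.parse_args(['--resume']).resume)  # True  -> fall back to get_latest_run()
print(parser.parse_args(['--resume', 'runs/exp0/weights/last.pt']).resume)  # the given path
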
@@ -413,21 +415,24 @@ if __name__ == '__main__':
     opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
     opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
     set_logging(opt.global_rank)
-
-    # Resume
-    if opt.resume:
-        last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
-        if last and not opt.weights:
-            logger.info(f'Resuming training from {last}')
-        opt.weights = last if opt.resume and not opt.weights else opt.weights
     if opt.global_rank in [-1, 0]:
         check_git_status()
 
-    opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
-    opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
-    assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
+    # Resume
+    if opt.resume:  # resume an interrupted run
+        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
+        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
+        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
+            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))  # replace
+        opt.cfg, opt.weights, opt.resume = '', ckpt, True
+        logger.info('Resuming training from %s' % ckpt)
+
+    else:
+        opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
+        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
+        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
+        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
 
-    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = select_device(opt.device, batch_size=opt.batch_size)
 
     # DDP mode
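
This is where the single-command behaviour and the argparse.Namespace() bullet meet: when --resume is set, the freshly parsed options are discarded and rebuilt from the opt.yaml that the run presumably saved at its top level (alongside hyp.yaml above), so only the checkpoint path needs to be known. A condensed, runnable sketch under that assumed runs/<exp>/{opt.yaml, weights/last.pt} layout; all names and values are illustrative:

import argparse
import tempfile
import yaml
from pathlib import Path

# Fabricate a tiny run directory that mimics the expected layout.
run = Path(tempfile.mkdtemp())
(run / 'weights').mkdir()
(run / 'weights' / 'last.pt').touch()
with open(run / 'opt.yaml', 'w') as f:
    yaml.dump({'epochs': 300, 'batch_size': 16, 'weights': '', 'cfg': 'models/yolov5s.yaml'}, f)

ckpt = str(run / 'weights' / 'last.pt')
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
    opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))  # rebuild the saved options
opt.cfg, opt.weights, opt.resume = '', ckpt, True                     # as in the diff above
print(opt)
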
 
utils/general.py CHANGED
@@ -61,7 +61,7 @@ def init_seeds(seed=0):
 def get_latest_run(search_dir='./runs'):
     # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
     last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
-    return max(last_list, key=os.path.getctime)
+    return max(last_list, key=os.path.getctime) if last_list else ''
 
 
 def check_git_status():
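
The one-line change covers the case where runs/ contains no last*.pt at all: max() on an empty sequence raises ValueError, so the very first bare --resume on a fresh clone would die with an opaque traceback instead of reaching the readable isfile assert in train.py. A small reproduction of the before/after behaviour (the directory name is arbitrary):

import glob
import os

def get_latest_run_old(search_dir='./does_not_exist'):
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime)            # ValueError when last_list is empty

def get_latest_run_new(search_dir='./does_not_exist'):
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''

print(repr(get_latest_run_new()))   # '' -> caller's assert gives a clear error message
try:
    get_latest_run_old()
except ValueError as e:
    print(e)                        # max() arg is an empty sequence
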