glenn-jocher commited on
Commit
a0ac5ad
1 Parent(s): 3c6e2f7

Single-source training update (#680)

Browse files
Files changed (1) hide show
  1. train.py +7 -6
train.py CHANGED
@@ -372,7 +372,7 @@ if __name__ == '__main__':
372
  parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
373
  parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
374
  parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
375
- parser.add_argument('--hyp', type=str, default='data/hyp.finetune.yaml', help='hyperparameters path')
376
  parser.add_argument('--epochs', type=int, default=300)
377
  parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
378
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
@@ -396,16 +396,17 @@ if __name__ == '__main__':
396
  opt = parser.parse_args()
397
 
398
  # Resume
399
- last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
400
- if last and not opt.weights:
401
- print(f'Resuming training from {last}')
402
- opt.weights = last if opt.resume and not opt.weights else opt.weights
 
403
  if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
404
  check_git_status()
405
 
 
406
  opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
407
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
408
- assert len(opt.hyp), '--hyp must be specified'
409
 
410
  opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
411
  device = select_device(opt.device, batch_size=opt.batch_size)
 
372
  parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
373
  parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
374
  parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
375
+ parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
376
  parser.add_argument('--epochs', type=int, default=300)
377
  parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
378
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
 
396
  opt = parser.parse_args()
397
 
398
  # Resume
399
+ if opt.resume:
400
+ last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
401
+ if last and not opt.weights:
402
+ print(f'Resuming training from {last}')
403
+ opt.weights = last if opt.resume and not opt.weights else opt.weights
404
  if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
405
  check_git_status()
406
 
407
+ opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
408
  opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
409
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
 
410
 
411
  opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
412
  device = select_device(opt.device, batch_size=opt.batch_size)