glenn-jocher committed
Commit fbf41e0 · unverified · 1 Parent(s): c1af67d

Add `train.run()` method (#3700)


* Update train.py explicit arguments

* Update train.py

* Add run method

Files changed (1)
  1. train.py +45 -36
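
The headline change is a new `run()` entry point so train.py can be driven from Python rather than only from the command line. A minimal usage sketch, based on the usage comment added in this commit (the `data` and `epochs` overrides below are illustrative; any field defined by `parse_opt()` can be passed as a keyword):

    import train  # yolov5/train.py importable on PYTHONPATH

    # each keyword overrides the matching attribute on the argparse Namespace
    # built by parse_opt(known=True); everything else keeps its default
    train.run(imgsz=320, weights='yolov5m.pt', data='coco128.yaml', epochs=3)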
train.py CHANGED
@@ -46,8 +46,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, weights, single_cls = \
-        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
+    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
+        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+        opt.resume, opt.notest, opt.nosave, opt.workers
 
     # Directories
     save_dir = Path(save_dir)
@@ -70,34 +71,34 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         yaml.safe_dump(vars(opt), f, sort_keys=False)
 
     # Configure
-    plots = not opt.evolve  # create plots
+    plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(2 + RANK)
-    with open(opt.data) as f:
+    with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict
 
     # Loggers
     loggers = {'wandb': None, 'tb': None}  # loggers dict
     if RANK in [-1, 0]:
         # TensorBoard
-        if not opt.evolve:
+        if not evolve:
             prefix = colorstr('tensorboard: ')
             logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(opt.save_dir)
+            loggers['tb'] = SummaryWriter(str(save_dir))
 
         # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
         loggers['wandb'] = wandb_logger.wandb
-        data_dict = wandb_logger.data_dict
-        if wandb_logger.wandb:
+        if loggers['wandb']:
+            data_dict = wandb_logger.data_dict
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update weights, epochs if resuming
 
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
-    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
-    is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
+    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data)  # check
+    is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
     # Model
     pretrained = weights.endswith('.pt')
@@ -105,14 +106,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         with torch_distributed_zero_first(RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
+        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
-        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
@@ -182,7 +183,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
 
         # Epochs
         start_epoch = ckpt['epoch'] + 1
-        if opt.resume:
+        if resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
             logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
@@ -210,20 +211,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
-                                            workers=opt.workers,
+                                            workers=workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
-    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
+    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
 
     # Process 0
     if RANK in [-1, 0]:
         testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
-                                       workers=opt.workers,
+                                       hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
+                                       workers=workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
 
-        if not opt.resume:
+        if not resume:
             labels = np.concatenate(dataset.labels, 0)
             c = torch.tensor(labels[:, 0])  # classes
             # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
@@ -356,8 +357,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
                         with warnings.catch_warnings():
                             warnings.simplefilter('ignore')  # suppress jit trace warning
                             loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-                elif plots and ni == 10 and wandb_logger.wandb:
-                    wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+                elif plots and ni == 10 and loggers['wandb']:
+                    wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
                                                   save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
@@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
             # mAP
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
-            if not opt.notest or final_epoch:  # Calculate mAP
+            if not notest or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = test.test(data_dict,
                                              batch_size=batch_size // WORLD_SIZE * 2,
@@ -398,7 +399,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if loggers['tb']:
                     loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     wandb_logger.log({tag: x})  # W&B
 
             # Update best mAP
@@ -408,7 +409,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
             wandb_logger.end_epoch(best_result=best_fitness == fi)
 
             # Save model
-            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
+            if (not nosave) or (final_epoch and not evolve):  # if save
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
@@ -416,13 +417,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                         wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
@@ -433,15 +434,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb_logger.wandb:
+            if loggers['wandb']:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
+                wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
                                               if (save_dir / f).exists()]})
 
-        if not opt.evolve:
+        if not evolve:
             if is_coco:  # COCO dataset
                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
-                    results, _, _ = test.test(opt.data,
+                    results, _, _ = test.test(data,
                                               batch_size=batch_size // WORLD_SIZE * 2,
                                               imgsz=imgsz_test,
                                               conf_thres=0.001,
@@ -457,17 +458,17 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
             for f in last, best:
                 if f.exists():
                     strip_optimizer(f)  # strip optimizers
-            if wandb_logger.wandb:  # Log the stripped model
-                wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
-                                                name='run_' + wandb_logger.wandb_run.id + '_model',
-                                                aliases=['latest', 'best', 'stripped'])
+            if loggers['wandb']:  # Log the stripped model
+                loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
+                                              name='run_' + wandb_logger.wandb_run.id + '_model',
+                                              aliases=['latest', 'best', 'stripped'])
         wandb_logger.finish_run()
 
     torch.cuda.empty_cache()
     return results
 
 
-def parse_opt():
+def parse_opt(known=False):
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
@@ -503,7 +504,7 @@ def parse_opt():
     parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
     parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    opt = parser.parse_args()
+    opt = parser.parse_known_args()[0] if known else parser.parse_args()
     return opt
 
 
@@ -633,6 +634,14 @@ def main(opt):
                     f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
 
 
+def run(**kwargs):
+    # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
+    opt = parse_opt(True)
+    for k, v in kwargs.items():
+        setattr(opt, k, v)
+    main(opt)
+
+
 if __name__ == "__main__":
     opt = parse_opt()
     main(opt)
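
Why `parse_opt()` gains a `known` flag: `run()` has to build the default options Namespace even when `sys.argv` belongs to something else (a notebook kernel, a test runner), and `parser.parse_args()` would exit on any token it does not recognize. A small self-contained sketch of the difference, using an illustrative flag rather than a real YOLOv5 option:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--imgsz', type=int, default=640)

    # parse_args() would raise SystemExit on the unknown token;
    # parse_known_args() returns it separately and keeps the parsed defaults
    opt, unknown = parser.parse_known_args(['--imgsz', '320', '--unrelated-flag'])
    print(opt.imgsz, unknown)  # 320 ['--unrelated-flag']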