glenn-jocher commited on
Commit
c923fbf
1 Parent(s): 1fc9d42

W&B artifacts feature addition (#1712)

Browse files
Files changed (1) hide show
  1. train.py +10 -5
train.py CHANGED
@@ -386,10 +386,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
386
 
387
  if rank in [-1, 0]:
388
  # Strip optimizers
 
389
  for f in [last, best]:
390
- if f.exists(): # is *.pt
391
- strip_optimizer(f) # strip optimizer
392
- os.system('gsutil cp %s gs://%s/weights' % (f, opt.bucket)) if opt.bucket else None # upload
 
393
 
394
  # Plots
395
  if plots:
@@ -398,9 +400,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
398
  files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
399
  wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
400
  if (save_dir / f).exists()]})
401
- logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
 
402
 
403
  # Test best.pt
 
404
  if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
405
  for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
406
  results, _, _ = test.test(opt.data,
@@ -408,7 +412,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
408
  imgsz=imgsz_test,
409
  conf_thres=conf,
410
  iou_thres=iou,
411
- model=attempt_load(best if best.exists() else last, device).half(),
412
  single_cls=opt.single_cls,
413
  dataloader=testloader,
414
  save_dir=save_dir,
@@ -448,6 +452,7 @@ if __name__ == '__main__':
448
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
449
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
450
  parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
 
451
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
452
  parser.add_argument('--project', default='runs/train', help='save to project/name')
453
  parser.add_argument('--name', default='exp', help='save to project/name')
 
386
 
387
  if rank in [-1, 0]:
388
  # Strip optimizers
389
+ final = best if best.exists() else last # final model
390
  for f in [last, best]:
391
+ if f.exists():
392
+ strip_optimizer(f) # strip optimizers
393
+ if opt.bucket:
394
+ os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
395
 
396
  # Plots
397
  if plots:
 
400
  files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
401
  wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
402
  if (save_dir / f).exists()]})
403
+ if opt.log_artifacts:
404
+ wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
405
 
406
  # Test best.pt
407
+ logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
408
  if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
409
  for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
410
  results, _, _ = test.test(opt.data,
 
412
  imgsz=imgsz_test,
413
  conf_thres=conf,
414
  iou_thres=iou,
415
+ model=attempt_load(final, device).half(),
416
  single_cls=opt.single_cls,
417
  dataloader=testloader,
418
  save_dir=save_dir,
 
452
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
453
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
454
  parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
455
+ parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
456
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
457
  parser.add_argument('--project', default='runs/train', help='save to project/name')
458
  parser.add_argument('--name', default='exp', help='save to project/name')