glenn-jocher commited on
Commit
9368453
1 Parent(s): 886b984

train.py --logdir argparser addition (#660)

Browse files

* train.py --logdir argparser addition

* train.py --logdir argparser addition

Files changed (1) hide show
  1. train.py +10 -9
train.py CHANGED
@@ -55,20 +55,20 @@ hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
55
 
56
  def train(hyp, opt, device, tb_writer=None):
57
  print(f'Hyperparameters {hyp}')
58
- log_dir = tb_writer.log_dir if tb_writer else 'runs/evolve' # run directory
59
- wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
60
  os.makedirs(wdir, exist_ok=True)
61
  last = wdir + 'last.pt'
62
  best = wdir + 'best.pt'
63
- results_file = log_dir + os.sep + 'results.txt'
64
  epochs, batch_size, total_batch_size, weights, rank = \
65
  opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
66
 
67
  # TODO: Use DDP logging. Only the first process is allowed to log.
68
  # Save run settings
69
- with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
70
  yaml.dump(hyp, f, sort_keys=False)
71
- with open(Path(log_dir) / 'opt.yaml', 'w') as f:
72
  yaml.dump(vars(opt), f, sort_keys=False)
73
 
74
  # Configure
@@ -325,7 +325,7 @@ def train(hyp, opt, device, tb_writer=None):
325
 
326
  # Plot
327
  if ni < 3:
328
- f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename
329
  result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
330
  if tb_writer and result is not None:
331
  tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
@@ -433,7 +433,8 @@ if __name__ == '__main__':
433
  parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
434
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
435
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
436
- parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
 
437
  opt = parser.parse_args()
438
 
439
  # Resume
@@ -472,8 +473,8 @@ if __name__ == '__main__':
472
  if not opt.evolve:
473
  tb_writer = None
474
  if opt.global_rank in [-1, 0]:
475
- print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
476
- tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
477
 
478
  train(hyp, opt, device, tb_writer)
479
 
 
55
 
56
  def train(hyp, opt, device, tb_writer=None):
57
  print(f'Hyperparameters {hyp}')
58
+ log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
59
+ wdir = str(log_dir / 'weights') + os.sep # weights directory
60
  os.makedirs(wdir, exist_ok=True)
61
  last = wdir + 'last.pt'
62
  best = wdir + 'best.pt'
63
+ results_file = str(log_dir / 'results.txt')
64
  epochs, batch_size, total_batch_size, weights, rank = \
65
  opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
66
 
67
  # TODO: Use DDP logging. Only the first process is allowed to log.
68
  # Save run settings
69
+ with open(log_dir / 'hyp.yaml', 'w') as f:
70
  yaml.dump(hyp, f, sort_keys=False)
71
+ with open(log_dir / 'opt.yaml', 'w') as f:
72
  yaml.dump(vars(opt), f, sort_keys=False)
73
 
74
  # Configure
 
325
 
326
  # Plot
327
  if ni < 3:
328
+ f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
329
  result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
330
  if tb_writer and result is not None:
331
  tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
 
433
  parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
434
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
435
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
436
+ parser.add_argument('--local-rank', type=int, default=-1, help='DDP parameter, do not modify')
437
+ parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
438
  opt = parser.parse_args()
439
 
440
  # Resume
 
473
  if not opt.evolve:
474
  tb_writer = None
475
  if opt.global_rank in [-1, 0]:
476
+ print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
477
+ tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp
478
 
479
  train(hyp, opt, device, tb_writer)
480