glenn-jocher committed
Commit b3e2f4e
1 parent: fad27c0

Eliminate `total_batch_size` variable (#3697)


* Eliminate `total_batch_size` variable

* cleanup

* Update train.py
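
In short, after this change `opt.batch_size` always holds the total batch size across all ranks, and the per-GPU share is derived where it is needed as `batch_size // WORLD_SIZE`; gradient accumulation and weight-decay scaling keep using the global value, as before. A minimal sketch of the resulting arithmetic, with illustrative numbers (only `nbs = 64` and the formulas come from this diff; the batch size, world size, and 0.0005 weight-decay value are example inputs):

    # Illustrative only: a total batch of 16 images split across 2 GPUs
    nbs = 64          # nominal batch size (from train.py)
    batch_size = 16   # opt.batch_size: the total across all ranks
    WORLD_SIZE = 2    # number of DDP processes

    accumulate = max(round(nbs / batch_size), 1)           # 4 -> step the optimizer every 4 batches
    weight_decay = 0.0005 * batch_size * accumulate / nbs  # 0.0005 * 16 * 4 / 64 = 0.0005
    per_rank_batch = batch_size // WORLD_SIZE              # 8 images per GPU in each trainloader

    print(accumulate, weight_decay, per_rank_batch)        # 4 0.0005 8
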

Files changed (1):
  1. train.py +12 -13
train.py
@@ -46,10 +46,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls
+    save_dir, epochs, batch_size, weights, single_cls = \
+        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
 
     # Directories
+    save_dir = Path(save_dir)
     wdir = save_dir / 'weights'
     wdir.mkdir(parents=True, exist_ok=True)  # make dir
     last = wdir / 'last.pt'
@@ -127,8 +128,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
-    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
+    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
+    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
@@ -205,7 +206,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         logger.info('Using SyncBatchNorm()')
 
     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
                                             workers=opt.workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
@@ -215,7 +216,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Process 0
     if RANK in [-1, 0]:
-        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
+        testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        workers=opt.workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
@@ -302,7 +303,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if ni <= nw:
                 xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
-                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
@@ -371,7 +372,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if not opt.notest or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = test.test(data_dict,
-                                             batch_size=batch_size * 2,
+                                             batch_size=batch_size // WORLD_SIZE * 2,
                                              imgsz=imgsz_test,
                                              model=ema.ema,
                                              single_cls=single_cls,
@@ -439,7 +440,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         if is_coco:  # COCO dataset
             for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
                 results, _, _ = test.test(opt.data,
-                                          batch_size=batch_size * 2,
+                                          batch_size=batch_size // WORLD_SIZE * 2,
                                           imgsz=imgsz_test,
                                           conf_thres=0.001,
                                           iou_thres=0.7,
@@ -518,7 +519,7 @@ def main(opt):
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size  # reinstate
+        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
         logger.info('Resuming training from %s' % ckpt)
     else:
         # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
@@ -529,17 +530,15 @@ def main(opt):
         opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))
 
     # DDP mode
-    opt.total_batch_size = opt.batch_size
     device = select_device(opt.device, batch_size=opt.batch_size)
     if LOCAL_RANK != -1:
         from datetime import timedelta
-        assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
+        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
         torch.cuda.set_device(LOCAL_RANK)
         device = torch.device('cuda', LOCAL_RANK)
         dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
         assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
         assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
-        opt.batch_size = opt.total_batch_size // WORLD_SIZE
 
     # Train
     if not opt.evolve:
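
For reference, a hedged sketch of how both loader batch sizes now fall out of the single `opt.batch_size` value. `WORLD_SIZE` stands in for the module-level constant that train.py reads from the environment (it defaults to 1 outside DDP, so single-GPU behavior is unchanged), and `loader_batches` is a hypothetical helper added here only to show the arithmetic:

    import os

    WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))  # 1 outside DDP

    def loader_batches(batch_size, world_size):
        # Per-rank batch sizes implied by the diff:
        # the trainloader gets batch_size // world_size, the val/test loader twice that
        train_bs = batch_size // world_size
        val_bs = train_bs * 2
        return train_bs, val_bs

    print(loader_batches(64, 1))  # (64, 128) - single GPU, same as before this commit
    print(loader_batches(64, 4))  # (16, 32)  - 4-GPU DDP: each rank loads a quarter of the batch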