glenn-jocher commited on
Commit
958ab92
1 Parent(s): 0cfc5b2

Remove `opt` from `create_dataloader()`` (#3552)

Browse files
Files changed (3) hide show
  1. test.py +1 -1
  2. train.py +9 -8
  3. utils/datasets.py +3 -3
test.py CHANGED
@@ -88,7 +88,7 @@ def test(data,
88
  if device.type != 'cpu':
89
  model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
90
  task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
91
- dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
92
  prefix=colorstr(f'{task}: '))[0]
93
 
94
  seen = 0
 
88
  if device.type != 'cpu':
89
  model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
90
  task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
91
+ dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
92
  prefix=colorstr(f'{task}: '))[0]
93
 
94
  seen = 0
train.py CHANGED
@@ -41,8 +41,9 @@ logger = logging.getLogger(__name__)
41
 
42
  def train(hyp, opt, device, tb_writer=None):
43
  logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
44
- save_dir, epochs, batch_size, total_batch_size, weights, rank = \
45
- Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
 
46
 
47
  # Directories
48
  wdir = save_dir / 'weights'
@@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None):
75
  if wandb_logger.wandb:
76
  weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
77
 
78
- nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
79
- names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
80
  assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
81
  is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
82
 
@@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None):
187
  logger.info('Using SyncBatchNorm()')
188
 
189
  # Trainloader
190
- dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
191
  hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
192
  world_size=opt.world_size, workers=opt.workers,
193
  image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
@@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None):
197
 
198
  # Process 0
199
  if rank in [-1, 0]:
200
- testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
201
  hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
202
  world_size=opt.world_size, workers=opt.workers,
203
  pad=0.5, prefix=colorstr('val: '))[0]
@@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
357
  batch_size=batch_size * 2,
358
  imgsz=imgsz_test,
359
  model=ema.ema,
360
- single_cls=opt.single_cls,
361
  dataloader=testloader,
362
  save_dir=save_dir,
363
  save_json=is_coco and final_epoch,
@@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None):
429
  conf_thres=0.001,
430
  iou_thres=0.7,
431
  model=attempt_load(m, device).half(),
432
- single_cls=opt.single_cls,
433
  dataloader=testloader,
434
  save_dir=save_dir,
435
  save_json=True,
 
41
 
42
  def train(hyp, opt, device, tb_writer=None):
43
  logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
44
+ save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
45
+ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
46
+ opt.single_cls
47
 
48
  # Directories
49
  wdir = save_dir / 'weights'
 
76
  if wandb_logger.wandb:
77
  weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
78
 
79
+ nc = 1 if single_cls else int(data_dict['nc']) # number of classes
80
+ names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
81
  assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
82
  is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
83
 
 
188
  logger.info('Using SyncBatchNorm()')
189
 
190
  # Trainloader
191
+ dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
192
  hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
193
  world_size=opt.world_size, workers=opt.workers,
194
  image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
 
198
 
199
  # Process 0
200
  if rank in [-1, 0]:
201
+ testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
202
  hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
203
  world_size=opt.world_size, workers=opt.workers,
204
  pad=0.5, prefix=colorstr('val: '))[0]
 
358
  batch_size=batch_size * 2,
359
  imgsz=imgsz_test,
360
  model=ema.ema,
361
+ single_cls=single_cls,
362
  dataloader=testloader,
363
  save_dir=save_dir,
364
  save_json=is_coco and final_epoch,
 
430
  conf_thres=0.001,
431
  iou_thres=0.7,
432
  model=attempt_load(m, device).half(),
433
+ single_cls=single_cls,
434
  dataloader=testloader,
435
  save_dir=save_dir,
436
  save_json=True,
utils/datasets.py CHANGED
@@ -62,8 +62,8 @@ def exif_size(img):
62
  return s
63
 
64
 
65
- def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
66
- rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
67
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
68
  with torch_distributed_zero_first(rank):
69
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -71,7 +71,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
71
  hyp=hyp, # augmentation hyperparameters
72
  rect=rect, # rectangular training
73
  cache_images=cache,
74
- single_cls=opt.single_cls,
75
  stride=int(stride),
76
  pad=pad,
77
  image_weights=image_weights,
 
62
  return s
63
 
64
 
65
+ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
66
+ rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
67
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
68
  with torch_distributed_zero_first(rank):
69
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
 
71
  hyp=hyp, # augmentation hyperparameters
72
  rect=rect, # rectangular training
73
  cache_images=cache,
74
+ single_cls=single_cls,
75
  stride=int(stride),
76
  pad=pad,
77
  image_weights=image_weights,