glenn-jocher commited on
Commit
69ff781
·
1 Parent(s): 987c226

opt.img_weights bug fix (#885)

Browse files
Files changed (1) hide show
  1. train.py +8 -10
train.py CHANGED
@@ -216,18 +216,15 @@ def train(hyp, opt, device, tb_writer=None):
216
  model.train()
217
 
218
  # Update image weights (optional)
219
- if dataset.image_weights:
220
  # Generate indices
221
  if rank in [-1, 0]:
222
- w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
223
- image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
224
- dataset.indices = random.choices(range(dataset.n), weights=image_weights,
225
- k=dataset.n) # rand weighted idx
226
  # Broadcast if DDP
227
  if rank != -1:
228
- indices = torch.zeros([dataset.n], dtype=torch.int)
229
- if rank == 0:
230
- indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
231
  dist.broadcast(indices, 0)
232
  if rank != 0:
233
  dataset.indices = indices.cpu().numpy()
@@ -388,7 +385,8 @@ if __name__ == '__main__':
388
  parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
389
  parser.add_argument('--epochs', type=int, default=300)
390
  parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
391
- parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
 
392
  parser.add_argument('--rect', action='store_true', help='rectangular training')
393
  parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
394
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
@@ -471,7 +469,7 @@ if __name__ == '__main__':
471
  'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
472
  'iou_t': (0, 0.1, 0.7), # IoU training threshold
473
  'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
474
- # 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
475
  'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
476
  'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
477
  'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
 
216
  model.train()
217
 
218
  # Update image weights (optional)
219
+ if opt.img_weights:
220
  # Generate indices
221
  if rank in [-1, 0]:
222
+ cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
223
+ iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
224
+ dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
 
225
  # Broadcast if DDP
226
  if rank != -1:
227
+ indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
 
 
228
  dist.broadcast(indices, 0)
229
  if rank != 0:
230
  dataset.indices = indices.cpu().numpy()
 
385
  parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
386
  parser.add_argument('--epochs', type=int, default=300)
387
  parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
388
+ parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
389
+ parser.add_argument('--img-weights', action='store_true', help='use weighted image selection for training')
390
  parser.add_argument('--rect', action='store_true', help='rectangular training')
391
  parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
392
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
 
469
  'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
470
  'iou_t': (0, 0.1, 0.7), # IoU training threshold
471
  'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
472
+ # 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
473
  'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
474
  'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
475
  'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)