glenn-jocher committed
Commit b569ed6
1 parent: 07a82f4

pretrained model loading bug fix (#450)


Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (1):
  train.py  +14 -11
train.py CHANGED
@@ -1,13 +1,12 @@
 import argparse

-import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
-from torch.utils.tensorboard import SummaryWriter
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter

 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
@@ -61,7 +60,7 @@ def train(hyp, tb_writer, opt, device):
         yaml.dump(vars(opt), f, sort_keys=False)

     epochs = opt.epochs  # 300
-    batch_size = opt.batch_size  # batch size per process.
+    batch_size = opt.batch_size  # batch size per process.
     total_batch_size = opt.total_batch_size
     weights = opt.weights  # initial training weights
     local_rank = opt.local_rank
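The distinction in this hunk matters under DDP: batch_size is each process's share, while total_batch_size is the global batch the user asked for. A tiny illustration of the split (the numbers and the even-division check below are mine, not taken from train.py):

    total_batch_size = 64  # global batch requested on the command line
    world_size = 4         # DDP processes, one per GPU

    assert total_batch_size % world_size == 0, 'batch size must divide evenly across processes'
    batch_size = total_batch_size // world_size  # per-process share -> 16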
@@ -70,7 +69,7 @@ def train(hyp, tb_writer, opt, device):
     # Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.

     # Configure
-    init_seeds(2+local_rank)
+    init_seeds(2 + local_rank)
     with open(opt.data) as f:
         data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
     train_path = data_dict['train']
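Seeding with 2 + local_rank gives each DDP process a distinct but reproducible random stream, so augmentations are not duplicated across ranks. A rough sketch of what such a helper usually covers; the body below is my approximation, not the repo's actual init_seeds:

    import random

    import numpy as np
    import torch

    def init_seeds(seed=0):
        # seed every RNG the training loop touches (approximation of the utils helper)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # seeds all devices, CUDA included

    init_seeds(2 + (-1))  # local_rank=-1 (no DDP) -> seed 1; rank r -> seed 2 + r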
@@ -131,7 +130,8 @@ def train(hyp, tb_writer, opt, device):

         # load model
         try:
-            ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() if k in model.state_dict()}
+            ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
+                             if k in model.state_dict() and model.state_dict()[k].shape == v.shape}
             model.load_state_dict(ckpt['model'], strict=False)
         except KeyError as e:
             s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
@@ -187,7 +187,8 @@ def train(hyp, tb_writer, opt, device):

     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
-                                            cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size)
+                                            cache=opt.cache_images, rect=opt.rect, local_rank=local_rank,
+                                            world_size=opt.world_size)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
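create_dataloader is the repo's own helper; passing local_rank and world_size lets it shard the training set so each DDP process iterates a distinct subset. The conventional way to get that behaviour in plain PyTorch is a DistributedSampler; the sketch below shows the pattern and is not the repo's implementation:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    def make_loader(dataset, batch_size, local_rank, world_size):
        sampler = None
        if local_rank != -1:  # running under DDP: give this rank its own shard
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
        return DataLoader(dataset, batch_size=batch_size,
                          shuffle=(sampler is None), sampler=sampler)

    ds = TensorDataset(torch.arange(16).float())
    loader = make_loader(ds, batch_size=4, local_rank=-1, world_size=1)  # single-process case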
@@ -195,8 +196,8 @@ def train(hyp, tb_writer, opt, device):
     # Testloader
     if local_rank in [-1, 0]:
         # local_rank is set to -1. Because only the first process is expected to do evaluation.
-        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
-                                       cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
+                                       cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]

     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -242,7 +243,8 @@ def train(hyp, tb_writer, opt, device):
         if local_rank in [-1, 0]:
             w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
             image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
-            dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx
+            dataset.indices = random.choices(range(dataset.n), weights=image_weights,
+                                             k=dataset.n)  # rand weighted idx
         # Broadcast.
         if local_rank != -1:
             indices = torch.zeros([dataset.n], dtype=torch.int)
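The rewrapped random.choices call resamples the training indices once per epoch, with each image's probability proportional to its weight, so images containing rare or poorly-performing classes are drawn more often. A toy version with hand-picked weights:

    import random

    image_weights = [0.1, 0.1, 0.7, 0.1]  # toy weights: image 2 holds the rare classes
    n = len(image_weights)

    indices = random.choices(range(n), weights=image_weights, k=n)  # n draws, with replacement
    print(indices)  # heavily favours index 2, e.g. [2, 2, 1, 2]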
@@ -402,7 +404,7 @@ def train(hyp, tb_writer, opt, device):
             plot_results()  # save as results.png
         print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))

-    dist.destroy_process_group() if local_rank not in [-1,0] else None
+    dist.destroy_process_group() if local_rank not in [-1, 0] else None
     torch.cuda.empty_cache()
     return results

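The conditional expression above is an inline way to tear the process group down on the non-primary ranks. In general, destroy_process_group pairs with the init_process_group call made at startup; a minimal lifecycle sketch (the 'gloo' backend and localhost rendezvous are illustrative choices, not necessarily what train.py uses):

    import os
    import torch.distributed as dist

    def ddp_lifecycle(rank, world_size):
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')  # single-machine rendezvous
        os.environ.setdefault('MASTER_PORT', '29500')
        dist.init_process_group('gloo', rank=rank, world_size=world_size)
        try:
            pass  # training loop would run here
        finally:
            dist.destroy_process_group()  # leave the group on the way out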
 
@@ -431,7 +433,8 @@ if __name__ == '__main__':
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
     # Parameter For DDP.
-    parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
+    parser.add_argument('--local_rank', type=int, default=-1,
+                        help="Extra parameter for DDP implementation. Don't use it manually.")
     opt = parser.parse_args()

     last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
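As the help string says, --local_rank is not meant to be set by hand: the era-appropriate launcher starts one copy of train.py per GPU and appends the flag itself. A usage sketch (the two-GPU command is an example invocation, not from the repo's docs):

    # launch one process per GPU; the launcher injects --local_rank=<n> into each copy's argv:
    #   python -m torch.distributed.launch --nproc_per_node=2 train.py --batch-size 64
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    opt = parser.parse_args(['--local_rank', '0'])  # simulate what the launcher passes to rank 0
    print(opt.local_rank)  # 0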
 