glenn-jocher commited on
Commit
ce36905
1 Parent(s): 1e84a23
Files changed (2) hide show
  1. test.py +1 -1
  2. train.py +15 -15
test.py CHANGED
@@ -256,7 +256,7 @@ if __name__ == '__main__':
256
  opt.augment)
257
 
258
  elif opt.task == 'study': # run over a range of settings and save/plot
259
- for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.p5', 'yolov5x.pt', 'yolov3-spp.pt']:
260
  f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
261
  x = list(range(256, 1024, 32)) # x axis
262
  y = [] # y axis
 
256
  opt.augment)
257
 
258
  elif opt.task == 'study': # run over a range of settings and save/plot
259
+ for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
260
  f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
261
  x = list(range(256, 1024, 32)) # x axis
262
  y = [] # y axis
train.py CHANGED
@@ -108,30 +108,30 @@ def train(hyp):
108
  google_utils.attempt_download(weights)
109
  start_epoch, best_fitness = 0, 0.0
110
  if weights.endswith('.pt'): # pytorch format
111
- chkpt = torch.load(weights, map_location=device)
112
 
113
  # load model
114
  try:
115
- chkpt['model'] = \
116
- {k: v for k, v in chkpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
117
- model.load_state_dict(chkpt['model'], strict=False)
118
  except KeyError as e:
119
  s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
120
  % (opt.weights, opt.cfg, opt.weights)
121
  raise KeyError(s) from e
122
 
123
  # load optimizer
124
- if chkpt['optimizer'] is not None:
125
- optimizer.load_state_dict(chkpt['optimizer'])
126
- best_fitness = chkpt['best_fitness']
127
 
128
  # load results
129
- if chkpt.get('training_results') is not None:
130
  with open(results_file, 'w') as file:
131
- file.write(chkpt['training_results']) # write results.txt
132
 
133
- start_epoch = chkpt['epoch'] + 1
134
- del chkpt
135
 
136
  # Mixed precision training https://github.com/NVIDIA/apex
137
  if mixed_precision:
@@ -324,17 +324,17 @@ def train(hyp):
324
  save = (not opt.nosave) or (final_epoch and not opt.evolve)
325
  if save:
326
  with open(results_file, 'r') as f: # create checkpoint
327
- chkpt = {'epoch': epoch,
328
  'best_fitness': best_fitness,
329
  'training_results': f.read(),
330
  'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
331
  'optimizer': None if final_epoch else optimizer.state_dict()}
332
 
333
  # Save last, best and delete
334
- torch.save(chkpt, last)
335
  if (best_fitness == fi) and not final_epoch:
336
- torch.save(chkpt, best)
337
- del chkpt
338
 
339
  # end epoch ----------------------------------------------------------------------------------------------------
340
  # end training
 
108
  google_utils.attempt_download(weights)
109
  start_epoch, best_fitness = 0, 0.0
110
  if weights.endswith('.pt'): # pytorch format
111
+ ckpt = torch.load(weights, map_location=device) # load checkpoint
112
 
113
  # load model
114
  try:
115
+ ckpt['model'] = \
116
+ {k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
117
+ model.load_state_dict(ckpt['model'], strict=False)
118
  except KeyError as e:
119
  s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
120
  % (opt.weights, opt.cfg, opt.weights)
121
  raise KeyError(s) from e
122
 
123
  # load optimizer
124
+ if ckpt['optimizer'] is not None:
125
+ optimizer.load_state_dict(ckpt['optimizer'])
126
+ best_fitness = ckpt['best_fitness']
127
 
128
  # load results
129
+ if ckpt.get('training_results') is not None:
130
  with open(results_file, 'w') as file:
131
+ file.write(ckpt['training_results']) # write results.txt
132
 
133
+ start_epoch = ckpt['epoch'] + 1
134
+ del ckpt
135
 
136
  # Mixed precision training https://github.com/NVIDIA/apex
137
  if mixed_precision:
 
324
  save = (not opt.nosave) or (final_epoch and not opt.evolve)
325
  if save:
326
  with open(results_file, 'r') as f: # create checkpoint
327
+ ckpt = {'epoch': epoch,
328
  'best_fitness': best_fitness,
329
  'training_results': f.read(),
330
  'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
331
  'optimizer': None if final_epoch else optimizer.state_dict()}
332
 
333
  # Save last, best and delete
334
+ torch.save(ckpt, last)
335
  if (best_fitness == fi) and not final_epoch:
336
+ torch.save(ckpt, best)
337
+ del ckpt
338
 
339
  # end epoch ----------------------------------------------------------------------------------------------------
340
  # end training