glenn-jocher committed
Commit
ec1d849
1 Parent(s): ca5b10b

Improved model+EMA checkpointing (#2292)


* Enhanced model+EMA checkpointing

* update

* bug fix

* bug fix 2

* always save optimizer

* ema half

* remove model.float()

* model half

* carry ema/model in fp32

* rm model.float()

* both to float always

* cleanup

* cleanup
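
The thrust of the change list above: each checkpoint now carries both the raw model and its EMA (each cast to FP16 to halve file size on disk), the EMA update counter, and the optimizer state on every save, so any last.pt can resume training exactly. For context, here is a minimal sketch of what a ModelEMA-style helper maintains; SimpleEMA and its 2000-step ramp constant are illustrative assumptions, not the repository's exact code.

# Illustrative sketch only, not utils/torch_utils.ModelEMA itself.
import math
from copy import deepcopy

import torch

class SimpleEMA:
    def __init__(self, model, decay=0.9999):
        self.ema = deepcopy(model).eval()  # shadow copy holding averaged weights
        self.updates = 0                   # update counter, now persisted in checkpoints
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Blend current weights into the shadow copy. The effective decay ramps
        # up with self.updates, which is why resuming needs the saved counter.
        self.updates += 1
        d = self.decay * (1 - math.exp(-self.updates / 2000))  # assumed ramp constant
        msd = model.state_dict()
        with torch.no_grad():
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v.mul_(d).add_(msd[k].detach(), alpha=1 - d)

Because the counter is saved, a resumed run no longer has to approximate it from start_epoch * nb // accumulate, which is the line this commit deletes from train.py.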

Files changed (3)
  1. test.py +0 -1
  2. train.py +16 -9
  3. utils/general.py +2 -2
test.py CHANGED
@@ -272,7 +272,6 @@ def test(data,
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
-        model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
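
test.py could drop its model.float() call because precision handling now lives entirely in train.py: .half() is applied only around torch.save() and undone immediately after. The restore is needed because nn.Module.half() casts parameters in place, as this standalone demonstration (not repo code) shows.

# Standalone demonstration, not repo code: nn.Module.half() casts in place.
import torch.nn as nn

m = nn.Linear(2, 2)
print(m.weight.dtype)  # torch.float32
m.half()               # converts parameters and buffers in place
print(m.weight.dtype)  # torch.float16
m.float()              # cast back before any further FP32 training
print(m.weight.dtype)  # torch.float32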
train.py CHANGED
@@ -31,7 +31,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
         loggers = {'wandb': wandb}  # loggers dict
 
+    # EMA
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']
 
+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
+            ema.updates = ckpt['ema'][1]
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')
 
-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
@@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),
+                        'ema': (ema.ema.half(), ema.updates),
+                        'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
 
                 # Save last, best and delete
@@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
+
+                model.float(), ema.ema.float()
+
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
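Taken together, the train.py changes amount to the save/resume round trip sketched below. save_ckpt and load_ckpt are hypothetical helper names for illustration; the repository inlines this logic.

# Hypothetical helpers condensing the new train.py logic; names are illustrative.
import torch

def save_ckpt(model, ema, optimizer, epoch, best_fitness, path='last.pt'):
    ckpt = {'epoch': epoch,
            'best_fitness': best_fitness,
            'model': model.half(),                 # FP16 halves checkpoint size
            'ema': (ema.ema.half(), ema.updates),  # EMA weights plus update counter
            'optimizer': optimizer.state_dict()}   # always saved, so any ckpt resumes
    torch.save(ckpt, path)
    model.float(), ema.ema.float()  # .half() was in place; restore FP32 to keep training

def load_ckpt(model, ema, optimizer, path='last.pt'):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'].float().state_dict())
    if ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
        ema.updates = ckpt['ema'][1]
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'] + 1, ckpt['best_fitness']  # next epoch to run
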
utils/general.py CHANGED
@@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
 def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for key in 'optimizer', 'training_results', 'wandb_id':
-        x[key] = None
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+        x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
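
With 'ema' added to the strip list, a finalized checkpoint keeps only the FP16 model for inference. Typical usage, with an illustrative path:

# Example usage; the path is illustrative.
from utils.general import strip_optimizer

strip_optimizer('runs/train/exp/weights/best.pt')  # nulls optimizer/results/wandb_id/ema, keeps FP16 model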