glenn-jocher committed
Commit 71dd276
1 Parent(s): ec1d849

Improved model+EMA checkpointing 2 (#2295)

Files changed (2):
  1. test.py +1 -0
  2. train.py +3 -4
test.py CHANGED
@@ -269,6 +269,7 @@ def test(data,
             print(f'pycocotools unable to run: {e}')
 
     # Return results
+    model.float()  # for training
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
train.py CHANGED
@@ -4,6 +4,7 @@ import math
 import os
 import random
 import time
+from copy import deepcopy
 from pathlib import Path
 from threading import Thread
 
@@ -381,8 +382,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': (model.module if is_parallel(model) else model).half(),
-                        'ema': (ema.ema.half(), ema.updates),
+                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                        'ema': (deepcopy(ema.ema).half(), ema.updates),
                         'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
 
@@ -392,8 +393,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                     torch.save(ckpt, best)
                 del ckpt
 
-                model.float(), ema.ema.float()
-
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
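The train.py side of the change: the checkpoint previously halved the live model and EMA in place, which is why a model.float(), ema.ema.float() restore had to run after every save. Deepcopying before .half() makes the serialized FP16 copy independent of the FP32 modules the optimizer is still updating, so the restore line can be dropped. A minimal sketch of the save/load round trip under that pattern; save_checkpoint and its parameters are illustrative, not YOLOv5 API:

```python
from copy import deepcopy

import torch
import torch.nn as nn

def save_checkpoint(model: nn.Module, ema_model: nn.Module, ema_updates: int,
                    optimizer: torch.optim.Optimizer, epoch: int, path: str) -> None:
    # Halve a *copy* for a compact FP16 checkpoint; the live FP32 model and
    # EMA that training keeps using are never modified.
    ckpt = {'epoch': epoch,
            'model': deepcopy(model).half(),
            'ema': (deepcopy(ema_model).half(), ema_updates),
            'optimizer': optimizer.state_dict()}
    torch.save(ckpt, path)

# Loading follows the same convention: cast back to FP32 to resume training.
# ckpt = torch.load(path)
# model = ckpt['model'].float()
# ema_model, ema_updates = ckpt['ema'][0].float(), ckpt['ema'][1]
```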