glenn-jocher commited on
Commit
ca5b10b
1 Parent(s): 0070995

Update train.py (#2290)

Browse files

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Create train.py

Files changed (1) hide show
  1. train.py +16 -19
train.py CHANGED
@@ -146,8 +146,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
146
 
147
  # Results
148
  if ckpt.get('training_results') is not None:
149
- with open(results_file, 'w') as file:
150
- file.write(ckpt['training_results']) # write results.txt
151
 
152
  # Epochs
153
  start_epoch = ckpt['epoch'] + 1
@@ -354,7 +353,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
354
 
355
  # Write
356
  with open(results_file, 'a') as f:
357
- f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
358
  if len(opt.name) and opt.bucket:
359
  os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
360
 
@@ -375,15 +374,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
375
  best_fitness = fi
376
 
377
  # Save model
378
- save = (not opt.nosave) or (final_epoch and not opt.evolve)
379
- if save:
380
- with open(results_file, 'r') as f: # create checkpoint
381
- ckpt = {'epoch': epoch,
382
- 'best_fitness': best_fitness,
383
- 'training_results': f.read(),
384
- 'model': ema.ema,
385
- 'optimizer': None if final_epoch else optimizer.state_dict(),
386
- 'wandb_id': wandb_run.id if wandb else None}
387
 
388
  # Save last, best and delete
389
  torch.save(ckpt, last)
@@ -396,9 +393,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
396
  if rank in [-1, 0]:
397
  # Strip optimizers
398
  final = best if best.exists() else last # final model
399
- for f in [last, best]:
400
  if f.exists():
401
- strip_optimizer(f) # strip optimizers
402
  if opt.bucket:
403
  os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
404
 
@@ -415,17 +412,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
415
  # Test best.pt
416
  logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
417
  if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
418
- for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
419
  results, _, _ = test.test(opt.data,
420
  batch_size=batch_size * 2,
421
  imgsz=imgsz_test,
422
- conf_thres=conf,
423
- iou_thres=iou,
424
- model=attempt_load(final, device).half(),
425
  single_cls=opt.single_cls,
426
  dataloader=testloader,
427
  save_dir=save_dir,
428
- save_json=save_json,
429
  plots=False)
430
 
431
  else:
 
146
 
147
  # Results
148
  if ckpt.get('training_results') is not None:
149
+ results_file.write_text(ckpt['training_results']) # write results.txt
 
150
 
151
  # Epochs
152
  start_epoch = ckpt['epoch'] + 1
 
353
 
354
  # Write
355
  with open(results_file, 'a') as f:
356
+ f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
357
  if len(opt.name) and opt.bucket:
358
  os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
359
 
 
374
  best_fitness = fi
375
 
376
  # Save model
377
+ if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
378
+ ckpt = {'epoch': epoch,
379
+ 'best_fitness': best_fitness,
380
+ 'training_results': results_file.read_text(),
381
+ 'model': ema.ema,
382
+ 'optimizer': None if final_epoch else optimizer.state_dict(),
383
+ 'wandb_id': wandb_run.id if wandb else None}
 
 
384
 
385
  # Save last, best and delete
386
  torch.save(ckpt, last)
 
393
  if rank in [-1, 0]:
394
  # Strip optimizers
395
  final = best if best.exists() else last # final model
396
+ for f in last, best:
397
  if f.exists():
398
+ strip_optimizer(f)
399
  if opt.bucket:
400
  os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
401
 
 
412
  # Test best.pt
413
  logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
414
  if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
415
+ for m in (last, best) if best.exists() else (last): # speed, mAP tests
416
  results, _, _ = test.test(opt.data,
417
  batch_size=batch_size * 2,
418
  imgsz=imgsz_test,
419
+ conf_thres=0.001,
420
+ iou_thres=0.7,
421
+ model=attempt_load(m, device).half(),
422
  single_cls=opt.single_cls,
423
  dataloader=testloader,
424
  save_dir=save_dir,
425
+ save_json=True,
426
  plots=False)
427
 
428
  else: