Jirka Borovec glenn-jocher commited on
Commit
c67e722
1 Parent(s): 4d3680c

fix compatibility for hyper config (#1146)

Browse files

* fix/hyper

* Hyp giou check to train.py

* restore general.py

* train.py overwrite fix

* restore general.py and pep8 update

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (2) hide show
  1. train.py +8 -3
  2. utils/general.py +2 -2
train.py CHANGED
@@ -5,6 +5,7 @@ import random
5
  import shutil
6
  import time
7
  from pathlib import Path
 
8
 
9
  import math
10
  import numpy as np
@@ -430,9 +431,8 @@ if __name__ == '__main__':
430
  opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
431
  log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
432
 
433
- device = select_device(opt.device, batch_size=opt.batch_size)
434
-
435
  # DDP mode
 
436
  if opt.local_rank != -1:
437
  assert torch.cuda.device_count() > opt.local_rank
438
  torch.cuda.set_device(opt.local_rank)
@@ -441,11 +441,16 @@ if __name__ == '__main__':
441
  assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
442
  opt.batch_size = opt.total_batch_size // opt.world_size
443
 
444
- logger.info(opt)
445
  with open(opt.hyp) as f:
446
  hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
 
 
 
 
447
 
448
  # Train
 
449
  if not opt.evolve:
450
  tb_writer = None
451
  if opt.global_rank in [-1, 0]:
 
5
  import shutil
6
  import time
7
  from pathlib import Path
8
+ from warnings import warn
9
 
10
  import math
11
  import numpy as np
 
431
  opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
432
  log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
433
 
 
 
434
  # DDP mode
435
+ device = select_device(opt.device, batch_size=opt.batch_size)
436
  if opt.local_rank != -1:
437
  assert torch.cuda.device_count() > opt.local_rank
438
  torch.cuda.set_device(opt.local_rank)
 
441
  assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
442
  opt.batch_size = opt.total_batch_size // opt.world_size
443
 
444
+ # Hyperparameters
445
  with open(opt.hyp) as f:
446
  hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
447
+ if 'box' not in hyp:
448
+ warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
449
+ (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
450
+ hyp['box'] = hyp.pop('giou')
451
 
452
  # Train
453
+ logger.info(opt)
454
  if not opt.evolve:
455
  tb_writer = None
456
  if opt.global_rank in [-1, 0]:
utils/general.py CHANGED
@@ -1,18 +1,18 @@
1
  import glob
2
  import logging
3
- import math
4
  import os
5
  import platform
6
  import random
 
7
  import shutil
8
  import subprocess
9
  import time
10
- import re
11
  from contextlib import contextmanager
12
  from copy import copy
13
  from pathlib import Path
14
 
15
  import cv2
 
16
  import matplotlib
17
  import matplotlib.pyplot as plt
18
  import numpy as np
 
1
  import glob
2
  import logging
 
3
  import os
4
  import platform
5
  import random
6
+ import re
7
  import shutil
8
  import subprocess
9
  import time
 
10
  from contextlib import contextmanager
11
  from copy import copy
12
  from pathlib import Path
13
 
14
  import cv2
15
+ import math
16
  import matplotlib
17
  import matplotlib.pyplot as plt
18
  import numpy as np