glenn-jocher commited on
Commit
31f3310
1 Parent(s): 19e68e8

assert best possible recall > 0.9 before training

Browse files
Files changed (3) hide show
  1. train.py +4 -1
  2. utils/datasets.py +13 -11
  3. utils/utils.py +13 -0
train.py CHANGED
@@ -191,7 +191,7 @@ def train(hyp):
191
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
192
  model.names = data_dict['names']
193
 
194
- # class frequency
195
  labels = np.concatenate(dataset.labels, 0)
196
  c = torch.tensor(labels[:, 0]) # classes
197
  # cf = torch.bincount(c.long(), minlength=nc) + 1.
@@ -199,6 +199,9 @@ def train(hyp):
199
  plot_labels(labels)
200
  tb_writer.add_histogram('classes', c, 0)
201
 
 
 
 
202
  # Exponential moving average
203
  ema = torch_utils.ModelEMA(model)
204
 
 
191
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
192
  model.names = data_dict['names']
193
 
194
+ # Class frequency
195
  labels = np.concatenate(dataset.labels, 0)
196
  c = torch.tensor(labels[:, 0]) # classes
197
  # cf = torch.bincount(c.long(), minlength=nc) + 1.
 
199
  plot_labels(labels)
200
  tb_writer.add_histogram('classes', c, 0)
201
 
202
+ # Check anchors
203
+ check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'])
204
+
205
  # Exponential moving average
206
  ema = torch_utils.ModelEMA(model)
207
 
utils/datasets.py CHANGED
@@ -291,20 +291,22 @@ class LoadImagesAndLabels(Dataset): # for training/testing
291
  self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
292
  for x in self.img_files]
293
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  # Rectangular Training https://github.com/ultralytics/yolov3/issues/232
295
  if self.rect:
296
- # Read image shapes (wh)
297
- sp = path.replace('.txt', '') + '.shapes' # shapefile path
298
- try:
299
- with open(sp, 'r') as f: # read existing shapefile
300
- s = [x.split() for x in f.read().splitlines()]
301
- assert len(s) == n, 'Shapefile out of sync'
302
- except:
303
- s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
304
- np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)
305
-
306
  # Sort by aspect ratio
307
- s = np.array(s, dtype=np.float64)
308
  ar = s[:, 1] / s[:, 0] # aspect ratio
309
  irect = ar.argsort()
310
  self.img_files = [self.img_files[i] for i in irect]
 
291
  self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
292
  for x in self.img_files]
293
 
294
+ # Read image shapes (wh)
295
+ sp = path.replace('.txt', '') + '.shapes' # shapefile path
296
+ try:
297
+ with open(sp, 'r') as f: # read existing shapefile
298
+ s = [x.split() for x in f.read().splitlines()]
299
+ assert len(s) == n, 'Shapefile out of sync'
300
+ except:
301
+ s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
302
+ np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)
303
+
304
+ self.shapes = np.array(s, dtype=np.float64)
305
+
306
  # Rectangular Training https://github.com/ultralytics/yolov3/issues/232
307
  if self.rect:
 
 
 
 
 
 
 
 
 
 
308
  # Sort by aspect ratio
309
+ s = self.shapes # wh
310
  ar = s[:, 1] / s[:, 0] # aspect ratio
311
  irect = ar.argsort()
312
  self.img_files = [self.img_files[i] for i in irect]
utils/utils.py CHANGED
@@ -51,6 +51,19 @@ def check_img_size(img_size, s=32):
51
  return make_divisible(img_size, s) # nearest gs-multiple
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def make_divisible(x, divisor):
55
  # Returns x evenly divisble by divisor
56
  return math.ceil(x / divisor) * divisor
 
51
  return make_divisible(img_size, s) # nearest gs-multiple
52
 
53
 
54
+ def check_best_possible_recall(dataset, anchors, thr):
55
+ # Check best possible recall of dataset with current anchors
56
+ wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])) # width-height
57
+ ratio = wh[:, None] / anchors.view(-1, 2)[None] # ratio
58
+ m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
59
+ bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
60
+ mr = (m < thr).float().mean() # match ratio
61
+ print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
62
+ print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
63
+ assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
64
+ 'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
65
+
66
+
67
  def make_divisible(x, divisor):
68
  # Returns x evenly divisble by divisor
69
  return math.ceil(x / divisor) * divisor