glenn-jocher committed
Commit
5fac5ad
1 Parent(s): 9eae82e

Precision-Recall Curve Feature Addition (#1107)


* initial commit

* Update general.py

Indent update

* Update general.py

refactor duplicate code

* 200 dpi

Files changed (3)
  1. test.py +10 -11
  2. train.py +2 -4
  3. utils/general.py +21 -16
test.py CHANGED
@@ -30,9 +30,9 @@ def test(data,
          verbose=False,
          model=None,
          dataloader=None,
-         save_dir='',
-         merge=False,
-         save_txt=False):
+         save_dir=Path(''),  # for saving images
+         save_txt=False,  # for auto-labelling
+         plots=True):
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
@@ -41,7 +41,7 @@ def test(data,
     else:  # called directly
         set_logging()
         device = select_device(opt.device, batch_size=batch_size)
-        merge, save_txt = opt.merge, opt.save_txt  # use Merge NMS, save *.txt labels
+        save_txt = opt.save_txt  # save *.txt labels
         if save_txt:
             out = Path('inference/output')
             if os.path.exists(out):
@@ -49,7 +49,7 @@ def test(data,
             os.makedirs(out)  # make new output folder

         # Remove previous
-        for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
+        for f in glob.glob(str(save_dir / 'test_batch*.jpg')):
            os.remove(f)

    # Load model
@@ -110,7 +110,7 @@ def test(data,

        # Run NMS
        t = time_synchronized()
-        output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
+        output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
        t1 += time_synchronized() - t

        # Statistics per image
@@ -186,16 +186,16 @@ def test(data,
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        # Plot images
-        if batch_i < 1:
-            f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i)  # filename
+        if plots and batch_i < 1:
+            f = save_dir / ('test_batch%g_gt.jpg' % batch_i)  # filename
            plot_images(img, targets, paths, str(f), names)  # ground truth
-            f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
+            f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
            plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
-        p, r, ap, f1, ap_class = ap_per_class(*stats)
+        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
@@ -261,7 +261,6 @@ if __name__ == '__main__':
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
-    parser.add_argument('--merge', action='store_true', help='use Merge NMS')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    opt = parser.parse_args()
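
Note that save_dir changes type here, from the empty string to Path(''): downstream joins can then use the / operator directly, which is why the Path(save_dir) wrappers disappear from the glob and plot-filename lines above. A minimal sketch of the difference (filenames illustrative, not part of the diff):

from pathlib import Path

save_dir = Path('')  # new default: a Path, not the old str ''
f = save_dir / 'test_batch0_gt.jpg'  # '/' joins work directly on a Path
print(f)  # test_batch0_gt.jpg

# With the old str default, every join needed an explicit wrap:
f_old = Path(save_dir) / 'test_batch0_gt.jpg'  # old pattern, now redundant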
train.py CHANGED
@@ -1,5 +1,4 @@
 import argparse
-import glob
 import logging
 import math
 import os
@@ -309,15 +308,14 @@ def train(hyp, opt, device, tb_writer=None):
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
-                if final_epoch:  # replot predictions
-                    [os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)]
                results, maps, times = test.test(opt.data,
                                                 batch_size=total_batch_size,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
-                                                 save_dir=log_dir)
+                                                 save_dir=log_dir,
+                                                 plots=epoch == 0 or final_epoch)  # plot first and last

            # Write
            with open(results_file, 'a') as f:
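
The deleted glob cleanup appears to be covered by two other changes in this commit: test.py only writes plots when plots is true (first and final epoch), and plot_images no longer returns early when the file exists (see the utils/general.py hunk below), so the final epoch's plots simply overwrite the first epoch's. A sketch of the new gating, with illustrative values for the loop variables:

epochs = 300  # illustrative; train.py reads this from opt
for epoch in range(epochs):
    final_epoch = epoch + 1 == epochs
    plots = epoch == 0 or final_epoch  # True only on the first and last epoch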
utils/general.py CHANGED
@@ -245,14 +245,16 @@ def clip_coords(boxes, img_shape):
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


-def ap_per_class(tp, conf, pred_cls, target_cls):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
-        tp: True positives (nparray, nx1 or nx10).
+        tp:  True positives (nparray, nx1 or nx10).
        conf: Objectness value from 0-1 (nparray).
-        pred_cls: Predicted object classes (nparray).
-        target_cls: True object classes (nparray).
+        pred_cls:  Predicted object classes (nparray).
+        target_cls:  True object classes (nparray).
+        plot:  Plot precision-recall curve at mAP@0.5
+        fname:  Plot filename
    # Returns
        The average precision as computed in py-faster-rcnn.
    """
@@ -265,6 +267,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
+    px, py = np.linspace(0, 1, 1000), []  # for plotting
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
@@ -289,22 +292,26 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score

            # AP from recall-precision curve
+            py.append(np.interp(px, recall[:, 0], precision[:, 0]))  # precision at mAP@0.5
            for j in range(tp.shape[1]):
                ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

-            # Plot
-            # fig, ax = plt.subplots(1, 1, figsize=(5, 5))
-            # ax.plot(recall, precision)
-            # ax.set_xlabel('Recall')
-            # ax.set_ylabel('Precision')
-            # ax.set_xlim(0, 1.01)
-            # ax.set_ylim(0, 1.01)
-            # fig.tight_layout()
-            # fig.savefig('PR_curve.png', dpi=300)
-
    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)

+    if plot:
+        py = np.stack(py, axis=1)
+        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
+        ax.plot(px, py, linewidth=0.5, color='grey')  # plot(recall, precision)
+        ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes')
+        ax.set_xlabel('Recall')
+        ax.set_ylabel('Precision')
+        ax.set_xlim(0, 1)
+        ax.set_ylim(0, 1)
+        plt.legend()
+        fig.tight_layout()
+        fig.savefig(fname, dpi=200)
+
    return p, r, ap, f1, unique_classes.astype('int32')


@@ -1011,8 +1018,6 @@ def plot_wh_methods():  # from utils.general import *; plot_wh_methods()
 def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
-    if os.path.isfile(fname):  # do not overwrite
-        return None

    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
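
For reference, the new if plot: block works because every class's precision is first interpolated onto the shared 1000-point recall grid px, so the per-class curves can be stacked column-wise and averaged. A self-contained sketch with synthetic curves for three made-up classes (the real function derives recall/precision from cumulative TP counts per class):

import matplotlib.pyplot as plt
import numpy as np

px = np.linspace(0, 1, 1000)  # shared recall grid, as in the diff
py = []  # per-class precision sampled onto px

for _ in range(3):  # three synthetic classes
    recall = np.linspace(0, 1, 50)  # must be increasing for np.interp
    precision = np.clip(1 - recall + 0.05 * np.random.randn(50), 0, 1)
    py.append(np.interp(px, recall, precision))  # precision at mAP@0.5

py = np.stack(py, axis=1)  # shape (1000, num_classes)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(px, py, linewidth=0.5, color='grey')  # one grey curve per class
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend()
fig.tight_layout()
fig.savefig('precision-recall_curve.png', dpi=200)  # 200 dpi per the commit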