glenn-jocher committed
Commit 4250f84
1 Parent(s): 8d2d6d2

Update PR curve (#1428)


* Update PR curve

* legend outside

* list(Path().glob())

Files changed (3)
  1. test.py +1 -1
  2. utils/metrics.py +25 -13
  3. utils/plots.py +9 -11
test.py CHANGED
@@ -213,7 +213,7 @@ def test(data,
     # Compute statistics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
-        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
+        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
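
Note: the call-site change above swaps the `fname` keyword for `save_dir` plus a `names` list, so test.py now passes the run directory and class names instead of a full plot filename. A minimal sketch of the new signature on synthetic inputs (illustrative only, not part of the commit; assumes a YOLOv5 checkout where `utils.metrics` is importable, and it relies on `ap_per_class` internals not shown in this diff):

    # Hedged sketch: exercising ap_per_class's new keywords with synthetic data.
    # Shapes mirror what test.py accumulates: tp is (n_pred, n_iou_thresholds),
    # conf and pred_cls are (n_pred,), target_cls is (n_targets,).
    import numpy as np
    from utils.metrics import ap_per_class

    np.random.seed(0)
    tp = np.random.rand(100, 10) > 0.5          # true-positive flags at 10 IoU thresholds
    conf = np.random.rand(100)                  # detection confidence scores
    pred_cls = np.random.randint(0, 3, 100)     # predicted class indices
    target_cls = np.random.randint(0, 3, 120)   # ground-truth class indices

    p, r, ap, f1, ap_class = ap_per_class(tp, conf, pred_cls, target_cls,
                                          plot=True, save_dir='.', names=['a', 'b', 'c'])
    print(ap.mean(0))  # mean AP per IoU threshold
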
utils/metrics.py CHANGED
@@ -1,5 +1,7 @@
 # Model validation metrics
 
+from pathlib import Path
+
 import matplotlib.pyplot as plt
 import numpy as np
 
@@ -10,7 +12,7 @@ def fitness(x):
     return (x[:, :4] * w).sum(1)
 
 
-def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments
@@ -19,7 +21,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re
         pred_cls: Predicted object classes (nparray).
         target_cls: True object classes (nparray).
         plot: Plot precision-recall curve at mAP@0.5
-        fname: Plot filename
+        save_dir: Plot save directory
     # Returns
         The average precision as computed in py-faster-rcnn.
     """
@@ -66,17 +68,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re
     f1 = 2 * p * r / (p + r + 1e-16)
 
     if plot:
-        py = np.stack(py, axis=1)
-        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
-        ax.plot(px, py, linewidth=0.5, color='grey')  # plot(recall, precision)
-        ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
-        ax.set_xlabel('Recall')
-        ax.set_ylabel('Precision')
-        ax.set_xlim(0, 1)
-        ax.set_ylim(0, 1)
-        plt.legend()
-        fig.tight_layout()
-        fig.savefig(fname, dpi=200)
+        plot_pr_curve(px, py, ap, save_dir, names)
 
     return p, r, ap, f1, unique_classes.astype('int32')
 
@@ -108,3 +100,23 @@ def compute_ap(recall, precision):
     ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve
 
     return ap, mpre, mrec
+
+
+def plot_pr_curve(px, py, ap, save_dir='.', names=()):
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6))
+    py = np.stack(py, axis=1)
+
+    if 0 < len(names) < 21:  # show mAP in legend if < 21 classes
+        for i, y in enumerate(py.T):
+            ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0])  # plot(recall, precision)
+    else:
+        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)
+
+    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
+    ax.set_xlabel('Recall')
+    ax.set_ylabel('Precision')
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    fig.tight_layout()
+    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
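
Note: the plotting code removed from `ap_per_class` now lives in the standalone `plot_pr_curve` above, which adds per-class legend entries for up to 20 named classes and moves the legend outside the axes via `bbox_to_anchor`, matching the "legend outside" bullet in the commit message. A hedged, self-contained sketch of calling it directly on synthetic curves:

    # Illustrative only: feed plot_pr_curve synthetic per-class PR data.
    # px is the recall grid; py is a list of per-class precision curves that
    # the function stacks column-wise; ap is (n_classes, n_iou_thresholds).
    import numpy as np
    from utils.metrics import plot_pr_curve

    px = np.linspace(0, 1, 1000)
    py = [np.clip(1 - px + 0.05 * np.random.randn(1000), 0, 1) for _ in range(3)]
    ap = np.random.rand(3, 10)  # synthetic AP at 10 IoU thresholds per class

    plot_pr_curve(px, py, ap, save_dir='.', names=['person', 'car', 'dog'])
    # -> writes ./precision_recall_curve.png; with 3 named classes the
    #    per-class branch runs and the legend sits to the right of the axes
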
utils/plots.py CHANGED
@@ -65,7 +65,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
         cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
 
 
-def plot_wh_methods():  # from utils.general import *; plot_wh_methods()
+def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
     # Compares the two methods for width-height anchor multiplication
     # https://github.com/ultralytics/yolov3/issues/168
     x = np.arange(-4.0, 4.0, .1)
@@ -200,7 +200,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
     plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
 
 
-def plot_test_txt():  # from utils.general import *; plot_test()
+def plot_test_txt():  # from utils.plots import *; plot_test()
     # Plot test.txt histograms
     x = np.loadtxt('test.txt', dtype=np.float32)
     box = xyxy2xywh(x[:, :4])
@@ -217,7 +217,7 @@ def plot_test_txt():  # from utils.general import *; plot_test()
     plt.savefig('hist1d.png', dpi=200)
 
 
-def plot_targets_txt():  # from utils.general import *; plot_targets_txt()
+def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
     # Plot targets.txt histograms
     x = np.loadtxt('targets.txt', dtype=np.float32).T
     s = ['x targets', 'y targets', 'width targets', 'height targets']
@@ -230,7 +230,7 @@ def plot_targets_txt():  # from utils.general import *; plot_targets_txt()
     plt.savefig('targets.jpg', dpi=200)
 
 
-def plot_study_txt(f='study.txt', x=None):  # from utils.general import *; plot_study_txt()
+def plot_study_txt(f='study.txt', x=None):  # from utils.plots import *; plot_study_txt()
     # Plot study.txt generated by test.py
     fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
     ax = ax.ravel()
@@ -294,7 +294,7 @@ def plot_labels(labels, save_dir=''):
     pass
 
 
-def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.general import *; plot_evolution()
+def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
     # Plot hyperparameter evolution results in evolve.txt
     with open(yaml_file) as f:
         hyp = yaml.load(f, Loader=yaml.FullLoader)
@@ -318,7 +318,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.general im
     print('\nPlot saved as evolve.png')
 
 
-def plot_results_overlay(start=0, stop=0):  # from utils.general import *; plot_results_overlay()
+def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
     # Plot training 'results*.txt', overlaying train and val losses
     s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
     t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
@@ -342,20 +342,18 @@ def plot_results_overlay(start=0, stop=0):  # from utils.general import *; plot_
 
 
 def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
-    # from utils.general import *; plot_results(save_dir='runs/train/exp0')
-    # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
+    # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
     fig, ax = plt.subplots(2, 5, figsize=(12, 6))
     ax = ax.ravel()
     s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
          'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
     if bucket:
-        # os.system('rm -rf storage.googleapis.com')
         # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
         files = ['results%g.txt' % x for x in id]
         c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
         os.system(c)
     else:
-        files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
+        files = list(Path(save_dir).glob('results*.txt'))
     assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
     for fi, f in enumerate(files):
         try:
@@ -367,7 +365,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
                 if i in [0, 1, 2, 5, 6, 7]:
                     y[y == 0] = np.nan  # don't show zero loss values
                     # y /= y[0]  # normalize
-                label = labels[fi] if len(labels) else Path(f).stem
+                label = labels[fi] if len(labels) else f.stem
                 ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6)
                 ax[i].set_title(s[i])
                 # if i in [5, 6, 7]:  # share train and val loss y axes
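
Note: the `list(Path().glob())` bullet from the commit message lands in the hunk above: the string-based `glob.glob` call (and its hard-coded `../../Downloads` fallback, now dropped) is replaced by `Path.glob`, which yields `Path` objects, so the later `Path(f).stem` simplifies to `f.stem`. A small standalone comparison of the two idioms (generic Python; the directory name is hypothetical):

    import glob
    from pathlib import Path

    save_dir = 'runs/train/exp'  # hypothetical results directory

    # before: glob.glob returns strings, so .stem needed an extra Path() wrap
    files_old = glob.glob(str(Path(save_dir) / 'results*.txt'))
    stems_old = [Path(f).stem for f in files_old]

    # after: Path.glob yields Path objects directly
    files_new = list(Path(save_dir).glob('results*.txt'))
    stems_new = [f.stem for f in files_new]

    assert sorted(stems_old) == sorted(stems_new)  # same files found either way
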