glenn-jocher committed on
Commit f639e14
1 Parent(s): 2a835c7

Metric-Confidence plots feature addition (#2057)


* Metric-Confidence plots feature addition

* cleanup

* Metric-Confidence plots feature addition

* cleanup

* Update run-once lines

* cleanup

* save all 4 curves to wandb

Files changed (3)
  1. test.py +1 -1
  2. train.py +1 -1
  3. utils/metrics.py +38 -15
test.py CHANGED
@@ -215,7 +215,7 @@ def test(data,
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
         p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
-        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
+        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
         mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
         nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
     else:
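
Note: `ap_per_class` now returns `p`, `r`, and `f1` already evaluated at the single max-F1 confidence (see the utils/metrics.py diff below), so test.py only needs to split `ap` into AP@0.5 and AP@0.5:0.95. A minimal sketch of that split, using a made-up 3-class AP array (the shapes and values here are assumptions for illustration, not repo output):

import numpy as np

ap = np.random.rand(3, 10)       # per-class AP at IoU thresholds 0.50, 0.55, ..., 0.95 (illustrative)
ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5 = first column; AP@0.5:0.95 = mean over all 10 thresholds
print(ap50.shape, ap.shape)      # (3,) (3,) -> one value per class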
train.py CHANGED
@@ -403,7 +403,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     if plots:
         plot_results(save_dir=save_dir)  # save as results.png
         if wandb:
-            files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
+            files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
             wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
                                    if (save_dir / f).exists()]})
     if opt.log_artifacts:
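
Note: the starred comprehension expands to six filenames: the four per-metric curve images from the "save all 4 curves to wandb" commit plus the two existing plots. A quick sketch of what the new `files` list evaluates to (plain Python, no wandb required):

files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
print(files)
# ['results.png', 'confusion_matrix.png', 'F1_curve.png', 'PR_curve.png', 'P_curve.png', 'R_curve.png']

Missing files are then filtered out by the `(save_dir / f).exists()` check, so runs without plots log nothing extra.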
utils/metrics.py CHANGED
@@ -15,7 +15,7 @@ def fitness(x):
     return (x[:, :4] * w).sum(1)
 
 
-def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments
@@ -35,12 +35,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
 
     # Find unique classes
     unique_classes = np.unique(target_cls)
+    nc = unique_classes.shape[0]  # number of classes, number of detections
 
     # Create Precision-Recall curve and compute AP for each class
     px, py = np.linspace(0, 1, 1000), []  # for plotting
-    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
-    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
-    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
+    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
     for ci, c in enumerate(unique_classes):
         i = pred_cls == c
         n_l = (target_cls == c).sum()  # number of labels
@@ -55,25 +54,28 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
 
             # Recall
             recall = tpc / (n_l + 1e-16)  # recall curve
-            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases
+            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
 
             # Precision
             precision = tpc / (tpc + fpc)  # precision curve
-            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score
+            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score
 
             # AP from recall-precision curve
             for j in range(tp.shape[1]):
                 ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
-                if plot and (j == 0):
+                if plot and j == 0:
                     py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5
 
-    # Compute F1 score (harmonic mean of precision and recall)
+    # Compute F1 (harmonic mean of precision and recall)
     f1 = 2 * p * r / (p + r + 1e-16)
-
     if plot:
-        plot_pr_curve(px, py, ap, save_dir, names)
+        plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
+        plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
+        plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
+        plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
 
-    return p, r, ap, f1, unique_classes.astype('int32')
+    i = f1.mean(0).argmax()  # max F1 index
+    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
 
 
 def compute_ap(recall, precision):
@@ -181,13 +183,14 @@ class ConfusionMatrix:
 
 # Plots ----------------------------------------------------------------------------------------------------------------
 
-def plot_pr_curve(px, py, ap, save_dir='.', names=()):
+def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
+    # Precision-recall curve
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
     py = np.stack(py, axis=1)
 
-    if 0 < len(names) < 21:  # show mAP in legend if < 10 classes
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
         for i, y in enumerate(py.T):
-            ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0])  # plot(recall, precision)
+            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
     else:
         ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)
 
@@ -197,4 +200,24 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
     ax.set_xlim(0, 1)
     ax.set_ylim(0, 1)
     plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
-    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
+    fig.savefig(Path(save_dir), dpi=250)
+
+
+def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
+    # Metric-confidence curve
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py):
+            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
+    else:
+        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)
+
+    y = py.mean(0)
+    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    fig.savefig(Path(save_dir), dpi=250)
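
Note: the core change is that precision and recall are now sampled on the shared 1000-point confidence grid `px` instead of at the fixed `pr_score = 0.1`. `np.interp` requires an increasing x-coordinate array, but `conf[i]` is sorted in decreasing order, so both the query points and the curve are negated; `left=0` / `left=1` pin recall to 0 and precision to 1 above the highest prediction confidence. A minimal sketch of the trick with made-up numbers (three detections of one class, two ground-truth labels; all values are illustrative):

import numpy as np

conf = np.array([0.9, 0.6, 0.3])    # detection confidences, decreasing as inside ap_per_class
recall = np.array([0.5, 0.5, 1.0])  # cumulative recall curve for n_l = 2 labels

print(np.interp(-0.95, -conf, recall, left=0))  # 0.0 -> no detections above confidence 0.95
print(np.interp(-0.70, -conf, recall, left=0))  # 0.5 -> only the 0.9 detection survives the threshold
print(np.interp(-0.10, -conf, recall, left=0))  # 1.0 -> all three detections kept

Because `p`, `r`, and `f1` are now full (nc, 1000) curves, `plot_mc_curve` can draw each metric against confidence, and `i = f1.mean(0).argmax()` picks the confidence where the class-averaged F1 peaks; `p[:, i]`, `r[:, i]`, and `f1[:, i]` then report every class at that single operating point, replacing the old hard-coded pr_score.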