glenn-jocher
commited on
Commit
•
f639e14
1
Parent(s):
2a835c7
Metric-Confidence plots feature addition (#2057)
Browse files* Metric-Confidence plots feature addition
* cleanup
* Metric-Confidence plots feature addition
* cleanup
* Update run-once lines
* cleanup
* save all 4 curves to wandb
- test.py +1 -1
- train.py +1 -1
- utils/metrics.py +38 -15
test.py
CHANGED
@@ -215,7 +215,7 @@ def test(data,
|
|
215 |
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
216 |
if len(stats) and stats[0].any():
|
217 |
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
|
218 |
-
|
219 |
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
220 |
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
221 |
else:
|
|
|
215 |
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
216 |
if len(stats) and stats[0].any():
|
217 |
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
|
218 |
+
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
|
219 |
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
220 |
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
221 |
else:
|
train.py
CHANGED
@@ -403,7 +403,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
403 |
if plots:
|
404 |
plot_results(save_dir=save_dir) # save as results.png
|
405 |
if wandb:
|
406 |
-
files = ['results.png', '
|
407 |
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
408 |
if (save_dir / f).exists()]})
|
409 |
if opt.log_artifacts:
|
|
|
403 |
if plots:
|
404 |
plot_results(save_dir=save_dir) # save as results.png
|
405 |
if wandb:
|
406 |
+
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
407 |
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
408 |
if (save_dir / f).exists()]})
|
409 |
if opt.log_artifacts:
|
utils/metrics.py
CHANGED
@@ -15,7 +15,7 @@ def fitness(x):
|
|
15 |
return (x[:, :4] * w).sum(1)
|
16 |
|
17 |
|
18 |
-
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='
|
19 |
""" Compute the average precision, given the recall and precision curves.
|
20 |
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
|
21 |
# Arguments
|
@@ -35,12 +35,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
|
|
35 |
|
36 |
# Find unique classes
|
37 |
unique_classes = np.unique(target_cls)
|
|
|
38 |
|
39 |
# Create Precision-Recall curve and compute AP for each class
|
40 |
px, py = np.linspace(0, 1, 1000), [] # for plotting
|
41 |
-
|
42 |
-
s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
|
43 |
-
ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
|
44 |
for ci, c in enumerate(unique_classes):
|
45 |
i = pred_cls == c
|
46 |
n_l = (target_cls == c).sum() # number of labels
|
@@ -55,25 +54,28 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
|
|
55 |
|
56 |
# Recall
|
57 |
recall = tpc / (n_l + 1e-16) # recall curve
|
58 |
-
r[ci] = np.interp(-
|
59 |
|
60 |
# Precision
|
61 |
precision = tpc / (tpc + fpc) # precision curve
|
62 |
-
p[ci] = np.interp(-
|
63 |
|
64 |
# AP from recall-precision curve
|
65 |
for j in range(tp.shape[1]):
|
66 |
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
|
67 |
-
if plot and
|
68 |
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
|
69 |
|
70 |
-
# Compute F1
|
71 |
f1 = 2 * p * r / (p + r + 1e-16)
|
72 |
-
|
73 |
if plot:
|
74 |
-
plot_pr_curve(px, py, ap, save_dir, names)
|
|
|
|
|
|
|
75 |
|
76 |
-
|
|
|
77 |
|
78 |
|
79 |
def compute_ap(recall, precision):
|
@@ -181,13 +183,14 @@ class ConfusionMatrix:
|
|
181 |
|
182 |
# Plots ----------------------------------------------------------------------------------------------------------------
|
183 |
|
184 |
-
def plot_pr_curve(px, py, ap, save_dir='.', names=()):
|
|
|
185 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
186 |
py = np.stack(py, axis=1)
|
187 |
|
188 |
-
if 0 < len(names) < 21: #
|
189 |
for i, y in enumerate(py.T):
|
190 |
-
ax.plot(px, y, linewidth=1, label=f'{names[i]}
|
191 |
else:
|
192 |
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
|
193 |
|
@@ -197,4 +200,24 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
|
|
197 |
ax.set_xlim(0, 1)
|
198 |
ax.set_ylim(0, 1)
|
199 |
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
200 |
-
fig.savefig(Path(save_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
return (x[:, :4] * w).sum(1)
|
16 |
|
17 |
|
18 |
+
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
|
19 |
""" Compute the average precision, given the recall and precision curves.
|
20 |
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
|
21 |
# Arguments
|
|
|
35 |
|
36 |
# Find unique classes
|
37 |
unique_classes = np.unique(target_cls)
|
38 |
+
nc = unique_classes.shape[0] # number of classes, number of detections
|
39 |
|
40 |
# Create Precision-Recall curve and compute AP for each class
|
41 |
px, py = np.linspace(0, 1, 1000), [] # for plotting
|
42 |
+
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
|
|
|
|
|
43 |
for ci, c in enumerate(unique_classes):
|
44 |
i = pred_cls == c
|
45 |
n_l = (target_cls == c).sum() # number of labels
|
|
|
54 |
|
55 |
# Recall
|
56 |
recall = tpc / (n_l + 1e-16) # recall curve
|
57 |
+
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
|
58 |
|
59 |
# Precision
|
60 |
precision = tpc / (tpc + fpc) # precision curve
|
61 |
+
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
|
62 |
|
63 |
# AP from recall-precision curve
|
64 |
for j in range(tp.shape[1]):
|
65 |
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
|
66 |
+
if plot and j == 0:
|
67 |
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
|
68 |
|
69 |
+
# Compute F1 (harmonic mean of precision and recall)
|
70 |
f1 = 2 * p * r / (p + r + 1e-16)
|
|
|
71 |
if plot:
|
72 |
+
plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
|
73 |
+
plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
|
74 |
+
plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
|
75 |
+
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
|
76 |
|
77 |
+
i = f1.mean(0).argmax() # max F1 index
|
78 |
+
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
|
79 |
|
80 |
|
81 |
def compute_ap(recall, precision):
|
|
|
183 |
|
184 |
# Plots ----------------------------------------------------------------------------------------------------------------
|
185 |
|
186 |
+
def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
|
187 |
+
# Precision-recall curve
|
188 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
189 |
py = np.stack(py, axis=1)
|
190 |
|
191 |
+
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
192 |
for i, y in enumerate(py.T):
|
193 |
+
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
|
194 |
else:
|
195 |
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
|
196 |
|
|
|
200 |
ax.set_xlim(0, 1)
|
201 |
ax.set_ylim(0, 1)
|
202 |
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
203 |
+
fig.savefig(Path(save_dir), dpi=250)
|
204 |
+
|
205 |
+
|
206 |
+
def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
|
207 |
+
# Metric-confidence curve
|
208 |
+
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
209 |
+
|
210 |
+
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
211 |
+
for i, y in enumerate(py):
|
212 |
+
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
|
213 |
+
else:
|
214 |
+
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
|
215 |
+
|
216 |
+
y = py.mean(0)
|
217 |
+
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
|
218 |
+
ax.set_xlabel(xlabel)
|
219 |
+
ax.set_ylabel(ylabel)
|
220 |
+
ax.set_xlim(0, 1)
|
221 |
+
ax.set_ylim(0, 1)
|
222 |
+
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
223 |
+
fig.savefig(Path(save_dir), dpi=250)
|