Sa-m commited on
Commit
25df9c3
1 Parent(s): 659b86c

Upload plots.py

Browse files
Files changed (1) hide show
  1. utils/plots.py +489 -0
utils/plots.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Plotting utils
2
+
3
+ import glob
4
+ import math
5
+ import os
6
+ import random
7
+ from copy import copy
8
+ from pathlib import Path
9
+
10
+ import cv2
11
+ import matplotlib
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ import pandas as pd
15
+ import seaborn as sns
16
+ import torch
17
+ import yaml
18
+ from PIL import Image, ImageDraw, ImageFont
19
+ from scipy.signal import butter, filtfilt
20
+
21
+ from utils.general import xywh2xyxy, xyxy2xywh
22
+ from utils.metrics import fitness
23
+
24
+ # Settings
25
+ matplotlib.rc('font', **{'size': 11})
26
+ matplotlib.use('Agg') # for writing to files only
27
+
28
+
29
+ def color_list():
30
+ # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
31
+ def hex2rgb(h):
32
+ return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
33
+
34
+ return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949)
35
+
36
+
37
+ def hist2d(x, y, n=100):
38
+ # 2d histogram used in labels.png and evolve.png
39
+ xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
40
+ hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
41
+ xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
42
+ yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
43
+ return np.log(hist[xidx, yidx])
44
+
45
+
46
+ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
47
+ # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
48
+ def butter_lowpass(cutoff, fs, order):
49
+ nyq = 0.5 * fs
50
+ normal_cutoff = cutoff / nyq
51
+ return butter(order, normal_cutoff, btype='low', analog=False)
52
+
53
+ b, a = butter_lowpass(cutoff, fs, order=order)
54
+ return filtfilt(b, a, data) # forward-backward filter
55
+
56
+
57
+ def plot_one_box(x, img, color=None, label=None, line_thickness=3):
58
+ # Plots one bounding box on image img
59
+ tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
60
+ color = color or [random.randint(0, 255) for _ in range(3)]
61
+ c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
62
+ cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
63
+ if label:
64
+ tf = max(tl - 1, 1) # font thickness
65
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
66
+ c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
67
+ cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
68
+ cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
69
+
70
+
71
+ def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None):
72
+ img = Image.fromarray(img)
73
+ draw = ImageDraw.Draw(img)
74
+ line_thickness = line_thickness or max(int(min(img.size) / 200), 2)
75
+ draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot
76
+ if label:
77
+ fontsize = max(round(max(img.size) / 40), 12)
78
+ font = ImageFont.truetype("Arial.ttf", fontsize)
79
+ txt_width, txt_height = font.getsize(label)
80
+ draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
81
+ draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
82
+ return np.asarray(img)
83
+
84
+
85
+ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
86
+ # Compares the two methods for width-height anchor multiplication
87
+ # https://github.com/ultralytics/yolov3/issues/168
88
+ x = np.arange(-4.0, 4.0, .1)
89
+ ya = np.exp(x)
90
+ yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
91
+
92
+ fig = plt.figure(figsize=(6, 3), tight_layout=True)
93
+ plt.plot(x, ya, '.-', label='YOLOv3')
94
+ plt.plot(x, yb ** 2, '.-', label='YOLOR ^2')
95
+ plt.plot(x, yb ** 1.6, '.-', label='YOLOR ^1.6')
96
+ plt.xlim(left=-4, right=4)
97
+ plt.ylim(bottom=0, top=6)
98
+ plt.xlabel('input')
99
+ plt.ylabel('output')
100
+ plt.grid()
101
+ plt.legend()
102
+ fig.savefig('comparison.png', dpi=200)
103
+
104
+
105
+ def output_to_target(output):
106
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
107
+ targets = []
108
+ for i, o in enumerate(output):
109
+ for *box, conf, cls in o.cpu().numpy():
110
+ targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
111
+ return np.array(targets)
112
+
113
+
114
+ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
115
+ # Plot image grid with labels
116
+
117
+ if isinstance(images, torch.Tensor):
118
+ images = images.cpu().float().numpy()
119
+ if isinstance(targets, torch.Tensor):
120
+ targets = targets.cpu().numpy()
121
+
122
+ # un-normalise
123
+ if np.max(images[0]) <= 1:
124
+ images *= 255
125
+
126
+ tl = 3 # line thickness
127
+ tf = max(tl - 1, 1) # font thickness
128
+ bs, _, h, w = images.shape # batch size, _, height, width
129
+ bs = min(bs, max_subplots) # limit plot images
130
+ ns = np.ceil(bs ** 0.5) # number of subplots (square)
131
+
132
+ # Check if we should resize
133
+ scale_factor = max_size / max(h, w)
134
+ if scale_factor < 1:
135
+ h = math.ceil(scale_factor * h)
136
+ w = math.ceil(scale_factor * w)
137
+
138
+ colors = color_list() # list of colors
139
+ mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
140
+ for i, img in enumerate(images):
141
+ if i == max_subplots: # if last batch has fewer images than we expect
142
+ break
143
+
144
+ block_x = int(w * (i // ns))
145
+ block_y = int(h * (i % ns))
146
+
147
+ img = img.transpose(1, 2, 0)
148
+ if scale_factor < 1:
149
+ img = cv2.resize(img, (w, h))
150
+
151
+ mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
152
+ if len(targets) > 0:
153
+ image_targets = targets[targets[:, 0] == i]
154
+ boxes = xywh2xyxy(image_targets[:, 2:6]).T
155
+ classes = image_targets[:, 1].astype('int')
156
+ labels = image_targets.shape[1] == 6 # labels if no conf column
157
+ conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
158
+
159
+ if boxes.shape[1]:
160
+ if boxes.max() <= 1.01: # if normalized with tolerance 0.01
161
+ boxes[[0, 2]] *= w # scale to pixels
162
+ boxes[[1, 3]] *= h
163
+ elif scale_factor < 1: # absolute coords need scale if image scales
164
+ boxes *= scale_factor
165
+ boxes[[0, 2]] += block_x
166
+ boxes[[1, 3]] += block_y
167
+ for j, box in enumerate(boxes.T):
168
+ cls = int(classes[j])
169
+ color = colors[cls % len(colors)]
170
+ cls = names[cls] if names else cls
171
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
172
+ label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
173
+ plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
174
+
175
+ # Draw image filename labels
176
+ if paths:
177
+ label = Path(paths[i]).name[:40] # trim to 40 char
178
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
179
+ cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
180
+ lineType=cv2.LINE_AA)
181
+
182
+ # Image border
183
+ cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
184
+
185
+ if fname:
186
+ r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
187
+ mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
188
+ # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
189
+ Image.fromarray(mosaic).save(fname) # PIL save
190
+ return mosaic
191
+
192
+
193
+ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
194
+ # Plot LR simulating training for full epochs
195
+ optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
196
+ y = []
197
+ for _ in range(epochs):
198
+ scheduler.step()
199
+ y.append(optimizer.param_groups[0]['lr'])
200
+ plt.plot(y, '.-', label='LR')
201
+ plt.xlabel('epoch')
202
+ plt.ylabel('LR')
203
+ plt.grid()
204
+ plt.xlim(0, epochs)
205
+ plt.ylim(0)
206
+ plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
207
+ plt.close()
208
+
209
+
210
+ def plot_test_txt(): # from utils.plots import *; plot_test()
211
+ # Plot test.txt histograms
212
+ x = np.loadtxt('test.txt', dtype=np.float32)
213
+ box = xyxy2xywh(x[:, :4])
214
+ cx, cy = box[:, 0], box[:, 1]
215
+
216
+ fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
217
+ ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
218
+ ax.set_aspect('equal')
219
+ plt.savefig('hist2d.png', dpi=300)
220
+
221
+ fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
222
+ ax[0].hist(cx, bins=600)
223
+ ax[1].hist(cy, bins=600)
224
+ plt.savefig('hist1d.png', dpi=200)
225
+
226
+
227
+ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
228
+ # Plot targets.txt histograms
229
+ x = np.loadtxt('targets.txt', dtype=np.float32).T
230
+ s = ['x targets', 'y targets', 'width targets', 'height targets']
231
+ fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
232
+ ax = ax.ravel()
233
+ for i in range(4):
234
+ ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
235
+ ax[i].legend()
236
+ ax[i].set_title(s[i])
237
+ plt.savefig('targets.jpg', dpi=200)
238
+
239
+
240
+ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
241
+ # Plot study.txt generated by test.py
242
+ fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
243
+ # ax = ax.ravel()
244
+
245
+ fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
246
+ # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolor-p6', 'yolor-w6', 'yolor-e6', 'yolor-d6']]:
247
+ for f in sorted(Path(path).glob('study*.txt')):
248
+ y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
249
+ x = np.arange(y.shape[1]) if x is None else np.array(x)
250
+ s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
251
+ # for i in range(7):
252
+ # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
253
+ # ax[i].set_title(s[i])
254
+
255
+ j = y[3].argmax() + 1
256
+ ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
257
+ label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
258
+
259
+ ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
260
+ 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
261
+
262
+ ax2.grid(alpha=0.2)
263
+ ax2.set_yticks(np.arange(20, 60, 5))
264
+ ax2.set_xlim(0, 57)
265
+ ax2.set_ylim(30, 55)
266
+ ax2.set_xlabel('GPU Speed (ms/img)')
267
+ ax2.set_ylabel('COCO AP val')
268
+ ax2.legend(loc='lower right')
269
+ plt.savefig(str(Path(path).name) + '.png', dpi=300)
270
+
271
+
272
+ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
273
+ # plot dataset labels
274
+ print('Plotting labels... ')
275
+ c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
276
+ nc = int(c.max() + 1) # number of classes
277
+ colors = color_list()
278
+ x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
279
+
280
+ # seaborn correlogram
281
+ sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
282
+ plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
283
+ plt.close()
284
+
285
+ # matplotlib labels
286
+ matplotlib.use('svg') # faster
287
+ ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
288
+ ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
289
+ ax[0].set_ylabel('instances')
290
+ if 0 < len(names) < 30:
291
+ ax[0].set_xticks(range(len(names)))
292
+ ax[0].set_xticklabels(names, rotation=90, fontsize=10)
293
+ else:
294
+ ax[0].set_xlabel('classes')
295
+ sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
296
+ sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
297
+
298
+ # rectangles
299
+ labels[:, 1:3] = 0.5 # center
300
+ labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
301
+ img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
302
+ for cls, *box in labels[:1000]:
303
+ ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
304
+ ax[1].imshow(img)
305
+ ax[1].axis('off')
306
+
307
+ for a in [0, 1, 2, 3]:
308
+ for s in ['top', 'right', 'left', 'bottom']:
309
+ ax[a].spines[s].set_visible(False)
310
+
311
+ plt.savefig(save_dir / 'labels.jpg', dpi=200)
312
+ matplotlib.use('Agg')
313
+ plt.close()
314
+
315
+ # loggers
316
+ for k, v in loggers.items() or {}:
317
+ if k == 'wandb' and v:
318
+ v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
319
+
320
+
321
+ def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
322
+ # Plot hyperparameter evolution results in evolve.txt
323
+ with open(yaml_file) as f:
324
+ hyp = yaml.load(f, Loader=yaml.SafeLoader)
325
+ x = np.loadtxt('evolve.txt', ndmin=2)
326
+ f = fitness(x)
327
+ # weights = (f - f.min()) ** 2 # for weighted results
328
+ plt.figure(figsize=(10, 12), tight_layout=True)
329
+ matplotlib.rc('font', **{'size': 8})
330
+ for i, (k, v) in enumerate(hyp.items()):
331
+ y = x[:, i + 7]
332
+ # mu = (y * weights).sum() / weights.sum() # best weighted result
333
+ mu = y[f.argmax()] # best single result
334
+ plt.subplot(6, 5, i + 1)
335
+ plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
336
+ plt.plot(mu, f.max(), 'k+', markersize=15)
337
+ plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
338
+ if i % 5 != 0:
339
+ plt.yticks([])
340
+ print('%15s: %.3g' % (k, mu))
341
+ plt.savefig('evolve.png', dpi=200)
342
+ print('\nPlot saved as evolve.png')
343
+
344
+
345
+ def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
346
+ # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
347
+ ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
348
+ s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
349
+ files = list(Path(save_dir).glob('frames*.txt'))
350
+ for fi, f in enumerate(files):
351
+ try:
352
+ results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
353
+ n = results.shape[1] # number of rows
354
+ x = np.arange(start, min(stop, n) if stop else n)
355
+ results = results[:, x]
356
+ t = (results[0] - results[0].min()) # set t0=0s
357
+ results[0] = x
358
+ for i, a in enumerate(ax):
359
+ if i < len(results):
360
+ label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
361
+ a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
362
+ a.set_title(s[i])
363
+ a.set_xlabel('time (s)')
364
+ # if fi == len(files) - 1:
365
+ # a.set_ylim(bottom=0)
366
+ for side in ['top', 'right']:
367
+ a.spines[side].set_visible(False)
368
+ else:
369
+ a.remove()
370
+ except Exception as e:
371
+ print('Warning: Plotting error for %s; %s' % (f, e))
372
+
373
+ ax[1].legend()
374
+ plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
375
+
376
+
377
+ def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
378
+ # Plot training 'results*.txt', overlaying train and val losses
379
+ s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
380
+ t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
381
+ for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
382
+ results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
383
+ n = results.shape[1] # number of rows
384
+ x = range(start, min(stop, n) if stop else n)
385
+ fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
386
+ ax = ax.ravel()
387
+ for i in range(5):
388
+ for j in [i, i + 5]:
389
+ y = results[j, x]
390
+ ax[i].plot(x, y, marker='.', label=s[j])
391
+ # y_smooth = butter_lowpass_filtfilt(y)
392
+ # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
393
+
394
+ ax[i].set_title(t[i])
395
+ ax[i].legend()
396
+ ax[i].set_ylabel(f) if i == 0 else None # add filename
397
+ fig.savefig(f.replace('.txt', '.png'), dpi=200)
398
+
399
+
400
+ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
401
+ # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
402
+ fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
403
+ ax = ax.ravel()
404
+ s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
405
+ 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
406
+ if bucket:
407
+ # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
408
+ files = ['results%g.txt' % x for x in id]
409
+ c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
410
+ os.system(c)
411
+ else:
412
+ files = list(Path(save_dir).glob('results*.txt'))
413
+ assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
414
+ for fi, f in enumerate(files):
415
+ try:
416
+ results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
417
+ n = results.shape[1] # number of rows
418
+ x = range(start, min(stop, n) if stop else n)
419
+ for i in range(10):
420
+ y = results[i, x]
421
+ if i in [0, 1, 2, 5, 6, 7]:
422
+ y[y == 0] = np.nan # don't show zero loss values
423
+ # y /= y[0] # normalize
424
+ label = labels[fi] if len(labels) else f.stem
425
+ ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
426
+ ax[i].set_title(s[i])
427
+ # if i in [5, 6, 7]: # share train and val loss y axes
428
+ # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
429
+ except Exception as e:
430
+ print('Warning: Plotting error for %s; %s' % (f, e))
431
+
432
+ ax[1].legend()
433
+ fig.savefig(Path(save_dir) / 'results.png', dpi=200)
434
+
435
+
436
+ def output_to_keypoint(output):
437
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
438
+ targets = []
439
+ for i, o in enumerate(output):
440
+ kpts = o[:,6:]
441
+ o = o[:,:6]
442
+ for index, (*box, conf, cls) in enumerate(o.detach().cpu().numpy()):
443
+ targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf, *list(kpts.detach().cpu().numpy()[index])])
444
+ return np.array(targets)
445
+
446
+
447
+ def plot_skeleton_kpts(im, kpts, steps, orig_shape=None):
448
+ #Plot the skeleton and keypointsfor coco datatset
449
+ palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
450
+ [230, 230, 0], [255, 153, 255], [153, 204, 255],
451
+ [255, 102, 255], [255, 51, 255], [102, 178, 255],
452
+ [51, 153, 255], [255, 153, 153], [255, 102, 102],
453
+ [255, 51, 51], [153, 255, 153], [102, 255, 102],
454
+ [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
455
+ [255, 255, 255]])
456
+
457
+ skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],
458
+ [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
459
+ [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
460
+
461
+ pose_limb_color = palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
462
+ pose_kpt_color = palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
463
+ radius = 5
464
+ num_kpts = len(kpts) // steps
465
+
466
+ for kid in range(num_kpts):
467
+ r, g, b = pose_kpt_color[kid]
468
+ x_coord, y_coord = kpts[steps * kid], kpts[steps * kid + 1]
469
+ if not (x_coord % 640 == 0 or y_coord % 640 == 0):
470
+ if steps == 3:
471
+ conf = kpts[steps * kid + 2]
472
+ if conf < 0.5:
473
+ continue
474
+ cv2.circle(im, (int(x_coord), int(y_coord)), radius, (int(r), int(g), int(b)), -1)
475
+
476
+ for sk_id, sk in enumerate(skeleton):
477
+ r, g, b = pose_limb_color[sk_id]
478
+ pos1 = (int(kpts[(sk[0]-1)*steps]), int(kpts[(sk[0]-1)*steps+1]))
479
+ pos2 = (int(kpts[(sk[1]-1)*steps]), int(kpts[(sk[1]-1)*steps+1]))
480
+ if steps == 3:
481
+ conf1 = kpts[(sk[0]-1)*steps+2]
482
+ conf2 = kpts[(sk[1]-1)*steps+2]
483
+ if conf1<0.5 or conf2<0.5:
484
+ continue
485
+ if pos1[0]%640 == 0 or pos1[1]%640==0 or pos1[0]<0 or pos1[1]<0:
486
+ continue
487
+ if pos2[0] % 640 == 0 or pos2[1] % 640 == 0 or pos2[0]<0 or pos2[1]<0:
488
+ continue
489
+ cv2.line(im, pos1, pos2, (int(r), int(g), int(b)), thickness=2)