glenn-jocher commited on
Commit
685d601
1 Parent(s): 49abc72

Increase plot_labels() speed (#1736)

Browse files
Files changed (2) hide show
  1. train.py +1 -1
  2. utils/plots.py +9 -17
train.py CHANGED
@@ -205,7 +205,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
205
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
206
  # model._initialize_biases(cf.to(device))
207
  if plots:
208
- Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
209
  if tb_writer:
210
  tb_writer.add_histogram('classes', c, 0)
211
 
 
205
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
206
  # model._initialize_biases(cf.to(device))
207
  if plots:
208
+ plot_labels(labels, save_dir, loggers)
209
  if tb_writer:
210
  tb_writer.add_histogram('classes', c, 0)
211
 
utils/plots.py CHANGED
@@ -11,6 +11,8 @@ import cv2
11
  import matplotlib
12
  import matplotlib.pyplot as plt
13
  import numpy as np
 
 
14
  import torch
15
  import yaml
16
  from PIL import Image, ImageDraw
@@ -253,34 +255,24 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
253
 
254
  def plot_labels(labels, save_dir=Path(''), loggers=None):
255
  # plot dataset labels
 
256
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
257
  nc = int(c.max() + 1) # number of classes
258
  colors = color_list()
 
259
 
260
  # seaborn correlogram
261
- try:
262
- import seaborn as sns
263
- import pandas as pd
264
- x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
265
- sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
266
- plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
267
- diag_kws=dict(bins=50))
268
- plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
269
- plt.close()
270
- except Exception as e:
271
- pass
272
 
273
  # matplotlib labels
274
  matplotlib.use('svg') # faster
275
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
276
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
277
  ax[0].set_xlabel('classes')
278
- ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
279
- ax[2].set_xlabel('x')
280
- ax[2].set_ylabel('y')
281
- ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
282
- ax[3].set_xlabel('width')
283
- ax[3].set_ylabel('height')
284
 
285
  # rectangles
286
  labels[:, 1:3] = 0.5 # center
 
11
  import matplotlib
12
  import matplotlib.pyplot as plt
13
  import numpy as np
14
+ import pandas as pd
15
+ import seaborn as sns
16
  import torch
17
  import yaml
18
  from PIL import Image, ImageDraw
 
255
 
256
  def plot_labels(labels, save_dir=Path(''), loggers=None):
257
  # plot dataset labels
258
+ print('Plotting labels... ')
259
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
260
  nc = int(c.max() + 1) # number of classes
261
  colors = color_list()
262
+ x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
263
 
264
  # seaborn correlogram
265
+ sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
266
+ plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
267
+ plt.close()
 
 
 
 
 
 
 
 
268
 
269
  # matplotlib labels
270
  matplotlib.use('svg') # faster
271
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
272
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
273
  ax[0].set_xlabel('classes')
274
+ sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
275
+ sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
 
 
 
 
276
 
277
  # rectangles
278
  labels[:, 1:3] = 0.5 # center