glenn-jocher commited on
Commit
b6ed110
1 Parent(s): 68211f7

Daemon thread plotting (#1561)

Browse files

* Daemon thread plotting

* remove process_batch

* plot after print

Files changed (3) hide show
  1. test.py +12 -11
  2. train.py +5 -5
  3. utils/plots.py +8 -3
test.py CHANGED
@@ -3,6 +3,7 @@ import glob
3
  import json
4
  import os
5
  from pathlib import Path
 
6
 
7
  import numpy as np
8
  import torch
@@ -206,10 +207,10 @@ def test(data,
206
 
207
  # Plot images
208
  if plots and batch_i < 3:
209
- f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename
210
- plot_images(img, targets, paths, f, names) # labels
211
- f = save_dir / f'test_batch{batch_i}_pred.jpg'
212
- plot_images(img, output_to_target(output), paths, f, names) # predictions
213
 
214
  # Compute statistics
215
  stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
@@ -221,13 +222,6 @@ def test(data,
221
  else:
222
  nt = torch.zeros(1)
223
 
224
- # Plots
225
- if plots:
226
- confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
227
- if wandb and wandb.run:
228
- wandb.log({"Images": wandb_images})
229
- wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
230
-
231
  # Print results
232
  pf = '%20s' + '%12.3g' * 6 # print format
233
  print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
@@ -242,6 +236,13 @@ def test(data,
242
  if not training:
243
  print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
244
 
 
 
 
 
 
 
 
245
  # Save JSON
246
  if save_json and len(jdict):
247
  w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
 
3
  import json
4
  import os
5
  from pathlib import Path
6
+ from threading import Thread
7
 
8
  import numpy as np
9
  import torch
 
207
 
208
  # Plot images
209
  if plots and batch_i < 3:
210
+ f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
211
+ Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
212
+ f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
213
+ Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()
214
 
215
  # Compute statistics
216
  stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
 
222
  else:
223
  nt = torch.zeros(1)
224
 
 
 
 
 
 
 
 
225
  # Print results
226
  pf = '%20s' + '%12.3g' * 6 # print format
227
  print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
 
236
  if not training:
237
  print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
238
 
239
+ # Plots
240
+ if plots:
241
+ confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
242
+ if wandb and wandb.run:
243
+ wandb.log({"Images": wandb_images})
244
+ wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
245
+
246
  # Save JSON
247
  if save_json and len(jdict):
248
  w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
train.py CHANGED
@@ -1,12 +1,13 @@
1
  import argparse
2
  import logging
3
- import math
4
  import os
5
  import random
6
  import time
7
  from pathlib import Path
 
8
  from warnings import warn
9
 
 
10
  import numpy as np
11
  import torch.distributed as dist
12
  import torch.nn as nn
@@ -134,6 +135,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
134
  project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
135
  name=save_dir.stem,
136
  id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
 
137
 
138
  # Resume
139
  start_epoch, best_fitness = 0, 0.0
@@ -201,11 +203,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
201
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
202
  # model._initialize_biases(cf.to(device))
203
  if plots:
204
- plot_labels(labels, save_dir=save_dir)
205
  if tb_writer:
206
  tb_writer.add_histogram('classes', c, 0)
207
- if wandb:
208
- wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
209
 
210
  # Anchors
211
  if not opt.noautoanchor:
@@ -311,7 +311,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
311
  # Plot
312
  if plots and ni < 3:
313
  f = save_dir / f'train_batch{ni}.jpg' # filename
314
- plot_images(images=imgs, targets=targets, paths=paths, fname=f)
315
  # if tb_writer:
316
  # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
317
  # tb_writer.add_graph(model, imgs) # add model to tensorboard
 
1
  import argparse
2
  import logging
 
3
  import os
4
  import random
5
  import time
6
  from pathlib import Path
7
+ from threading import Thread
8
  from warnings import warn
9
 
10
+ import math
11
  import numpy as np
12
  import torch.distributed as dist
13
  import torch.nn as nn
 
135
  project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
136
  name=save_dir.stem,
137
  id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
138
+ loggers = {'wandb': wandb} # loggers dict
139
 
140
  # Resume
141
  start_epoch, best_fitness = 0, 0.0
 
203
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
204
  # model._initialize_biases(cf.to(device))
205
  if plots:
206
+ Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
207
  if tb_writer:
208
  tb_writer.add_histogram('classes', c, 0)
 
 
209
 
210
  # Anchors
211
  if not opt.noautoanchor:
 
311
  # Plot
312
  if plots and ni < 3:
313
  f = save_dir / f'train_batch{ni}.jpg' # filename
314
+ Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
315
  # if tb_writer:
316
  # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
317
  # tb_writer.add_graph(model, imgs) # add model to tensorboard
utils/plots.py CHANGED
@@ -250,7 +250,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
250
  plt.savefig('test_study.png', dpi=300)
251
 
252
 
253
- def plot_labels(labels, save_dir=''):
254
  # plot dataset labels
255
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
256
  nc = int(c.max() + 1) # number of classes
@@ -264,7 +264,7 @@ def plot_labels(labels, save_dir=''):
264
  sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
265
  plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
266
  diag_kws=dict(bins=50))
267
- plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
268
  plt.close()
269
  except Exception as e:
270
  pass
@@ -292,9 +292,14 @@ def plot_labels(labels, save_dir=''):
292
  for a in [0, 1, 2, 3]:
293
  for s in ['top', 'right', 'left', 'bottom']:
294
  ax[a].spines[s].set_visible(False)
295
- plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
296
  plt.close()
297
 
 
 
 
 
 
298
 
299
  def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
300
  # Plot hyperparameter evolution results in evolve.txt
 
250
  plt.savefig('test_study.png', dpi=300)
251
 
252
 
253
+ def plot_labels(labels, save_dir=Path(''), loggers=None):
254
  # plot dataset labels
255
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
256
  nc = int(c.max() + 1) # number of classes
 
264
  sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
265
  plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
266
  diag_kws=dict(bins=50))
267
+ plt.savefig(save_dir / 'labels_correlogram.png', dpi=200)
268
  plt.close()
269
  except Exception as e:
270
  pass
 
292
  for a in [0, 1, 2, 3]:
293
  for s in ['top', 'right', 'left', 'bottom']:
294
  ax[a].spines[s].set_visible(False)
295
+ plt.savefig(save_dir / 'labels.png', dpi=200)
296
  plt.close()
297
 
298
+ # loggers
299
+ for k, v in loggers.items() or {}:
300
+ if k == 'wandb' and v:
301
+ v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
302
+
303
 
304
  def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
305
  # Plot hyperparameter evolution results in evolve.txt