glenn-jocher pre-commit-ci[bot] commited on
Commit
4a295b1
1 Parent(s): 9d8ed37

Add `@threaded` decorator (#7813)

Browse files

* Add `@threaded` decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (5) hide show
  1. train.py +2 -2
  2. utils/general.py +11 -0
  3. utils/loggers/__init__.py +3 -4
  4. utils/plots.py +7 -6
  5. val.py +2 -5
train.py CHANGED
@@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader
48
  from utils.downloads import attempt_download
49
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50
  check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
51
- init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
52
- methods, one_cycle, print_args, print_mutation, strip_optimizer)
53
  from utils.loggers import Loggers
54
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
55
  from utils.loss import ComputeLoss
 
48
  from utils.downloads import attempt_download
49
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50
  check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
51
+ init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
52
+ one_cycle, print_args, print_mutation, strip_optimizer)
53
  from utils.loggers import Loggers
54
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
55
  from utils.loss import ComputeLoss
utils/general.py CHANGED
@@ -14,6 +14,7 @@ import random
14
  import re
15
  import shutil
16
  import signal
 
17
  import time
18
  import urllib
19
  from datetime import datetime
@@ -167,6 +168,16 @@ def try_except(func):
167
  return handler
168
 
169
 
 
 
 
 
 
 
 
 
 
 
170
  def methods(instance):
171
  # Get class/instance methods
172
  return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
 
14
  import re
15
  import shutil
16
  import signal
17
+ import threading
18
  import time
19
  import urllib
20
  from datetime import datetime
 
168
  return handler
169
 
170
 
171
+ def threaded(func):
172
+ # Multi-threads a target function and returns thread. Usage: @threaded decorator
173
+ def wrapper(*args, **kwargs):
174
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
175
+ thread.start()
176
+ return thread
177
+
178
+ return wrapper
179
+
180
+
181
  def methods(instance):
182
  # Get class/instance methods
183
  return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
utils/loggers/__init__.py CHANGED
@@ -5,7 +5,6 @@ Logging utils
5
 
6
  import os
7
  import warnings
8
- from threading import Thread
9
 
10
  import pkg_resources as pkg
11
  import torch
@@ -109,7 +108,7 @@ class Loggers():
109
  self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
110
  if ni < 3:
111
  f = self.save_dir / f'train_batch{ni}.jpg' # filename
112
- Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
113
  if self.wandb and ni == 10:
114
  files = sorted(self.save_dir.glob('train*.jpg'))
115
  self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
@@ -132,7 +131,7 @@ class Loggers():
132
 
133
  def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
134
  # Callback runs at the end of each fit (train+val) epoch
135
- x = {k: v for k, v in zip(self.keys, vals)} # dict
136
  if self.csv:
137
  file = self.save_dir / 'results.csv'
138
  n = len(x) + 1 # number of cols
@@ -171,7 +170,7 @@ class Loggers():
171
  self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
172
 
173
  if self.wandb:
174
- self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results
175
  self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
176
  # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
177
  if not self.opt.evolve:
 
5
 
6
  import os
7
  import warnings
 
8
 
9
  import pkg_resources as pkg
10
  import torch
 
108
  self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
109
  if ni < 3:
110
  f = self.save_dir / f'train_batch{ni}.jpg' # filename
111
+ plot_images(imgs, targets, paths, f)
112
  if self.wandb and ni == 10:
113
  files = sorted(self.save_dir.glob('train*.jpg'))
114
  self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
 
131
 
132
  def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
133
  # Callback runs at the end of each fit (train+val) epoch
134
+ x = dict(zip(self.keys, vals))
135
  if self.csv:
136
  file = self.save_dir / 'results.csv'
137
  n = len(x) + 1 # number of cols
 
170
  self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
171
 
172
  if self.wandb:
173
+ self.wandb.log(dict(zip(self.keys[3:10], results)))
174
  self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
175
  # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
176
  if not self.opt.evolve:
utils/plots.py CHANGED
@@ -19,7 +19,7 @@ import torch
19
  from PIL import Image, ImageDraw, ImageFont
20
 
21
  from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
22
- increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
23
  from utils.metrics import fitness
24
 
25
  # Settings
@@ -32,9 +32,9 @@ class Colors:
32
  # Ultralytics color palette https://ultralytics.com/
33
  def __init__(self):
34
  # hex = matplotlib.colors.TABLEAU_COLORS.values()
35
- hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
36
- '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
37
- self.palette = [self.hex2rgb('#' + c) for c in hex]
38
  self.n = len(self.palette)
39
 
40
  def __call__(self, i, bgr=False):
@@ -100,7 +100,7 @@ class Annotator:
100
  if label:
101
  tf = max(self.lw - 1, 1) # font thickness
102
  w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
103
- outside = p1[1] - h - 3 >= 0 # label fits outside box
104
  p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
105
  cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
106
  cv2.putText(self.im,
@@ -184,6 +184,7 @@ def output_to_target(output):
184
  return np.array(targets)
185
 
186
 
 
187
  def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
188
  # Plot image grid with labels
189
  if isinstance(images, torch.Tensor):
@@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''):
420
  ax = ax.ravel()
421
  files = list(save_dir.glob('results*.csv'))
422
  assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
423
- for fi, f in enumerate(files):
424
  try:
425
  data = pd.read_csv(f)
426
  s = [x.strip() for x in data.columns]
 
19
  from PIL import Image, ImageDraw, ImageFont
20
 
21
  from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
22
+ increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
23
  from utils.metrics import fitness
24
 
25
  # Settings
 
32
  # Ultralytics color palette https://ultralytics.com/
33
  def __init__(self):
34
  # hex = matplotlib.colors.TABLEAU_COLORS.values()
35
+ hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
36
+ '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
37
+ self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
38
  self.n = len(self.palette)
39
 
40
  def __call__(self, i, bgr=False):
 
100
  if label:
101
  tf = max(self.lw - 1, 1) # font thickness
102
  w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
103
+ outside = p1[1] - h >= 3
104
  p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
105
  cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
106
  cv2.putText(self.im,
 
184
  return np.array(targets)
185
 
186
 
187
+ @threaded
188
  def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
189
  # Plot image grid with labels
190
  if isinstance(images, torch.Tensor):
 
421
  ax = ax.ravel()
422
  files = list(save_dir.glob('results*.csv'))
423
  assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
424
+ for f in files:
425
  try:
426
  data = pd.read_csv(f)
427
  s = [x.strip() for x in data.columns]
val.py CHANGED
@@ -23,7 +23,6 @@ import json
23
  import os
24
  import sys
25
  from pathlib import Path
26
- from threading import Thread
27
 
28
  import numpy as np
29
  import torch
@@ -255,10 +254,8 @@ def run(
255
 
256
  # Plot images
257
  if plots and batch_i < 3:
258
- f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels
259
- Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
260
- f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
261
- Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
262
 
263
  callbacks.run('on_val_batch_end')
264
 
 
23
  import os
24
  import sys
25
  from pathlib import Path
 
26
 
27
  import numpy as np
28
  import torch
 
254
 
255
  # Plot images
256
  if plots and batch_i < 3:
257
+ plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
258
+ plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
 
 
259
 
260
  callbacks.run('on_val_batch_end')
261