glenn-jocher commited on
Commit
1d1c056
·
unverified ·
1 Parent(s): ffef771

PyTorch Hub results.render() (#1897)

Browse files
Files changed (3) hide show
  1. models/common.py +13 -6
  2. train.py +6 -5
  3. utils/general.py +4 -3
models/common.py CHANGED
@@ -1,6 +1,7 @@
1
  # This file contains modules common to various models
2
 
3
  import math
 
4
  import numpy as np
5
  import requests
6
  import torch
@@ -240,7 +241,7 @@ class Detections:
240
  self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
241
  self.n = len(self.pred)
242
 
243
- def display(self, pprint=False, show=False, save=False):
244
  colors = color_list()
245
  for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
246
  str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
@@ -248,19 +249,21 @@ class Detections:
248
  for c in pred[:, -1].unique():
249
  n = (pred[:, -1] == c).sum() # detections per class
250
  str += f'{n} {self.names[int(c)]}s, ' # add to string
251
- if show or save:
252
  img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
253
  for *box, conf, cls in pred: # xyxy, confidence, class
254
  # str += '%s %.2f, ' % (names[int(cls)], conf) # label
255
  ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
 
 
 
 
256
  if save:
257
  f = f'results{i}.jpg'
258
  str += f"saved to '{f}'"
259
  img.save(f) # save
260
- if show:
261
- img.show(f'Image {i}') # show
262
- if pprint:
263
- print(str)
264
 
265
  def print(self):
266
  self.display(pprint=True) # print results
@@ -271,6 +274,10 @@ class Detections:
271
  def save(self):
272
  self.display(save=True) # save results
273
 
 
 
 
 
274
  def __len__(self):
275
  return self.n
276
 
 
1
  # This file contains modules common to various models
2
 
3
  import math
4
+
5
  import numpy as np
6
  import requests
7
  import torch
 
241
  self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
242
  self.n = len(self.pred)
243
 
244
+ def display(self, pprint=False, show=False, save=False, render=False):
245
  colors = color_list()
246
  for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
247
  str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
 
249
  for c in pred[:, -1].unique():
250
  n = (pred[:, -1] == c).sum() # detections per class
251
  str += f'{n} {self.names[int(c)]}s, ' # add to string
252
+ if show or save or render:
253
  img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
254
  for *box, conf, cls in pred: # xyxy, confidence, class
255
  # str += '%s %.2f, ' % (names[int(cls)], conf) # label
256
  ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
257
+ if pprint:
258
+ print(str)
259
+ if show:
260
+ img.show(f'Image {i}') # show
261
  if save:
262
  f = f'results{i}.jpg'
263
  str += f"saved to '{f}'"
264
  img.save(f) # save
265
+ if render:
266
+ self.imgs[i] = np.asarray(img)
 
 
267
 
268
  def print(self):
269
  self.display(pprint=True) # print results
 
274
  def save(self):
275
  self.display(save=True) # save results
276
 
277
+ def render(self):
278
+ self.display(render=True) # render results
279
+ return self.imgs
280
+
281
  def __len__(self):
282
  return self.n
283
 
train.py CHANGED
@@ -28,7 +28,7 @@ from utils.autoanchor import check_anchors
28
  from utils.datasets import create_dataloader
29
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
30
  fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
31
- check_requirements, print_mutation, set_logging, one_cycle
32
  from utils.google_utils import attempt_download
33
  from utils.loss import compute_loss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
@@ -44,7 +44,7 @@ except ImportError:
44
 
45
 
46
  def train(hyp, opt, device, tb_writer=None, wandb=None):
47
- logger.info(f'Hyperparameters {hyp}')
48
  save_dir, epochs, batch_size, total_batch_size, weights, rank = \
49
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
50
 
@@ -233,9 +233,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
233
  results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
234
  scheduler.last_epoch = start_epoch - 1 # do not move
235
  scaler = amp.GradScaler(enabled=cuda)
236
- logger.info('Image sizes %g train, %g test\n'
237
- 'Using %g dataloader workers\nLogging results to %s\n'
238
- 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
 
239
  for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
240
  model.train()
241
 
 
28
  from utils.datasets import create_dataloader
29
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
30
  fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
31
+ check_requirements, print_mutation, set_logging, one_cycle, colorstr
32
  from utils.google_utils import attempt_download
33
  from utils.loss import compute_loss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 
44
 
45
 
46
  def train(hyp, opt, device, tb_writer=None, wandb=None):
47
+ logger.info(colorstr('blue', 'bold', 'Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
48
  save_dir, epochs, batch_size, total_batch_size, weights, rank = \
49
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
50
 
 
233
  results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
234
  scheduler.last_epoch = start_epoch - 1 # do not move
235
  scaler = amp.GradScaler(enabled=cuda)
236
+ logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
237
+ f'Using {dataloader.num_workers} dataloader workers\n'
238
+ f'Logging results to {save_dir}\n'
239
+ f'Starting training for {epochs} epochs...')
240
  for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
241
  model.train()
242
 
utils/general.py CHANGED
@@ -25,6 +25,7 @@ from utils.torch_utils import init_torch_seeds
25
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
26
  np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
27
  cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 
28
 
29
 
30
  def set_logging(rank=-1):
@@ -117,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
117
 
118
  def colorstr(*input):
119
  # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
120
- *prefix, str = input # color arguments, string
121
  colors = {'black': '\033[30m', # basic colors
122
  'red': '\033[31m',
123
  'green': '\033[32m',
@@ -136,9 +137,9 @@ def colorstr(*input):
136
  'bright_white': '\033[97m',
137
  'end': '\033[0m', # misc
138
  'bold': '\033[1m',
139
- 'undelrine': '\033[4m'}
140
 
141
- return ''.join(colors[x] for x in prefix) + str + colors['end']
142
 
143
 
144
  def labels_to_class_weights(labels, nc=80):
 
25
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
26
  np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
27
  cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
28
+ os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
29
 
30
 
31
  def set_logging(rank=-1):
 
118
 
119
  def colorstr(*input):
120
  # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
121
+ *prefix, string = input # color arguments, string
122
  colors = {'black': '\033[30m', # basic colors
123
  'red': '\033[31m',
124
  'green': '\033[32m',
 
137
  'bright_white': '\033[97m',
138
  'end': '\033[0m', # misc
139
  'bold': '\033[1m',
140
+ 'underline': '\033[4m'}
141
 
142
+ return ''.join(colors[x] for x in prefix) + f'{string}' + colors['end']
143
 
144
 
145
  def labels_to_class_weights(labels, nc=80):