glenn-jocher
commited on
PyTorch Hub results.render() (#1897)
Browse files- models/common.py +13 -6
- train.py +6 -5
- utils/general.py +4 -3
models/common.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
# This file contains modules common to various models
|
2 |
|
3 |
import math
|
|
|
4 |
import numpy as np
|
5 |
import requests
|
6 |
import torch
|
@@ -240,7 +241,7 @@ class Detections:
|
|
240 |
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
241 |
self.n = len(self.pred)
|
242 |
|
243 |
-
def display(self, pprint=False, show=False, save=False):
|
244 |
colors = color_list()
|
245 |
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
|
246 |
str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
|
@@ -248,19 +249,21 @@ class Detections:
|
|
248 |
for c in pred[:, -1].unique():
|
249 |
n = (pred[:, -1] == c).sum() # detections per class
|
250 |
str += f'{n} {self.names[int(c)]}s, ' # add to string
|
251 |
-
if show or save:
|
252 |
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
|
253 |
for *box, conf, cls in pred: # xyxy, confidence, class
|
254 |
# str += '%s %.2f, ' % (names[int(cls)], conf) # label
|
255 |
ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
|
|
|
|
|
|
|
|
|
256 |
if save:
|
257 |
f = f'results{i}.jpg'
|
258 |
str += f"saved to '{f}'"
|
259 |
img.save(f) # save
|
260 |
-
if
|
261 |
-
|
262 |
-
if pprint:
|
263 |
-
print(str)
|
264 |
|
265 |
def print(self):
|
266 |
self.display(pprint=True) # print results
|
@@ -271,6 +274,10 @@ class Detections:
|
|
271 |
def save(self):
|
272 |
self.display(save=True) # save results
|
273 |
|
|
|
|
|
|
|
|
|
274 |
def __len__(self):
|
275 |
return self.n
|
276 |
|
|
|
1 |
# This file contains modules common to various models
|
2 |
|
3 |
import math
|
4 |
+
|
5 |
import numpy as np
|
6 |
import requests
|
7 |
import torch
|
|
|
241 |
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
242 |
self.n = len(self.pred)
|
243 |
|
244 |
+
def display(self, pprint=False, show=False, save=False, render=False):
|
245 |
colors = color_list()
|
246 |
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
|
247 |
str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
|
|
|
249 |
for c in pred[:, -1].unique():
|
250 |
n = (pred[:, -1] == c).sum() # detections per class
|
251 |
str += f'{n} {self.names[int(c)]}s, ' # add to string
|
252 |
+
if show or save or render:
|
253 |
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
|
254 |
for *box, conf, cls in pred: # xyxy, confidence, class
|
255 |
# str += '%s %.2f, ' % (names[int(cls)], conf) # label
|
256 |
ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
|
257 |
+
if pprint:
|
258 |
+
print(str)
|
259 |
+
if show:
|
260 |
+
img.show(f'Image {i}') # show
|
261 |
if save:
|
262 |
f = f'results{i}.jpg'
|
263 |
str += f"saved to '{f}'"
|
264 |
img.save(f) # save
|
265 |
+
if render:
|
266 |
+
self.imgs[i] = np.asarray(img)
|
|
|
|
|
267 |
|
268 |
def print(self):
|
269 |
self.display(pprint=True) # print results
|
|
|
274 |
def save(self):
|
275 |
self.display(save=True) # save results
|
276 |
|
277 |
+
def render(self):
|
278 |
+
self.display(render=True) # render results
|
279 |
+
return self.imgs
|
280 |
+
|
281 |
def __len__(self):
|
282 |
return self.n
|
283 |
|
train.py
CHANGED
@@ -28,7 +28,7 @@ from utils.autoanchor import check_anchors
|
|
28 |
from utils.datasets import create_dataloader
|
29 |
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
30 |
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
|
31 |
-
check_requirements, print_mutation, set_logging, one_cycle
|
32 |
from utils.google_utils import attempt_download
|
33 |
from utils.loss import compute_loss
|
34 |
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
@@ -44,7 +44,7 @@ except ImportError:
|
|
44 |
|
45 |
|
46 |
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
47 |
-
logger.info(
|
48 |
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
49 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
50 |
|
@@ -233,9 +233,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
233 |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
234 |
scheduler.last_epoch = start_epoch - 1 # do not move
|
235 |
scaler = amp.GradScaler(enabled=cuda)
|
236 |
-
logger.info('Image sizes
|
237 |
-
'Using
|
238 |
-
'
|
|
|
239 |
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
240 |
model.train()
|
241 |
|
|
|
28 |
from utils.datasets import create_dataloader
|
29 |
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
30 |
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
|
31 |
+
check_requirements, print_mutation, set_logging, one_cycle, colorstr
|
32 |
from utils.google_utils import attempt_download
|
33 |
from utils.loss import compute_loss
|
34 |
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
|
|
44 |
|
45 |
|
46 |
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
47 |
+
logger.info(colorstr('blue', 'bold', 'Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
48 |
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
49 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
50 |
|
|
|
233 |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
234 |
scheduler.last_epoch = start_epoch - 1 # do not move
|
235 |
scaler = amp.GradScaler(enabled=cuda)
|
236 |
+
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
|
237 |
+
f'Using {dataloader.num_workers} dataloader workers\n'
|
238 |
+
f'Logging results to {save_dir}\n'
|
239 |
+
f'Starting training for {epochs} epochs...')
|
240 |
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
241 |
model.train()
|
242 |
|
utils/general.py
CHANGED
@@ -25,6 +25,7 @@ from utils.torch_utils import init_torch_seeds
|
|
25 |
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
26 |
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
27 |
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
|
|
28 |
|
29 |
|
30 |
def set_logging(rank=-1):
|
@@ -117,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
|
|
117 |
|
118 |
def colorstr(*input):
|
119 |
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
120 |
-
*prefix,
|
121 |
colors = {'black': '\033[30m', # basic colors
|
122 |
'red': '\033[31m',
|
123 |
'green': '\033[32m',
|
@@ -136,9 +137,9 @@ def colorstr(*input):
|
|
136 |
'bright_white': '\033[97m',
|
137 |
'end': '\033[0m', # misc
|
138 |
'bold': '\033[1m',
|
139 |
-
'
|
140 |
|
141 |
-
return ''.join(colors[x] for x in prefix) +
|
142 |
|
143 |
|
144 |
def labels_to_class_weights(labels, nc=80):
|
|
|
25 |
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
26 |
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
27 |
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
28 |
+
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
|
29 |
|
30 |
|
31 |
def set_logging(rank=-1):
|
|
|
118 |
|
119 |
def colorstr(*input):
|
120 |
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
121 |
+
*prefix, string = input # color arguments, string
|
122 |
colors = {'black': '\033[30m', # basic colors
|
123 |
'red': '\033[31m',
|
124 |
'green': '\033[32m',
|
|
|
137 |
'bright_white': '\033[97m',
|
138 |
'end': '\033[0m', # misc
|
139 |
'bold': '\033[1m',
|
140 |
+
'underline': '\033[4m'}
|
141 |
|
142 |
+
return ''.join(colors[x] for x in prefix) + f'{string}' + colors['end']
|
143 |
|
144 |
|
145 |
def labels_to_class_weights(labels, nc=80):
|