Commit b74929c
Parent: d8f1883

Add `train.py` and `val.py` callbacks (#4220)

* Added callbacks
* Update callbacks.py
* Update train.py
* Update val.py
* Fix CamelCase, add staticmethod
* Refactor logger into callbacks
* Cleanup
* New callback on_val_image_end()
* Add curves and results images to TensorBoard
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed:
- train.py +19 -10
- utils/callbacks.py +176 -0
- utils/general.py +5 -0
- utils/loggers/__init__.py +24 -21
- utils/plots.py +1 -5
- val.py +5 -5
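
The pieces fit together as follows: the new `methods()` helper in utils/general.py enumerates an object's public methods, train.py registers every public method of the `Loggers` instance as an action on a `Callbacks` object, and the train/val loops then fire hooks by name. A minimal standalone sketch of the same pattern, using a toy stand-in for `Loggers` rather than the real class (assumes the repo root is on the import path):

from utils.callbacks import Callbacks
from utils.general import methods

class ToyLogger:
    # stand-in for Loggers: every public method is named after a valid hook
    def on_train_epoch_end(self, epoch):
        print(f'epoch {epoch} finished')

callbacks, logger = Callbacks(), ToyLogger()
for k in methods(logger):
    callbacks.register_action(k, callback=getattr(logger, k))

callbacks.on_train_epoch_end(epoch=0)  # fires ToyLogger.on_train_epoch_end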
train.py
CHANGED
@@ -34,7 +34,7 @@ from utils.autoanchor import check_anchors
 from utils.datasets import create_dataloader
 from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
     strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
-    check_requirements, print_mutation, set_logging, one_cycle, colorstr
+    check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
 from utils.downloads import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_labels, plot_evolution

@@ -42,6 +42,7 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, ...
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
 from utils.loggers import Loggers
+from utils.callbacks import Callbacks
 
 LOGGER = logging.getLogger(__name__)
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html

@@ -52,6 +53,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
+          callbacks=Callbacks()
           ):
     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \

@@ -77,12 +79,16 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Loggers
     if RANK in [-1, 0]:
-        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start()
+        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
         if loggers.wandb:
             data_dict = loggers.wandb.data_dict
             if resume:
                 weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
 
+        # Register actions
+        for k in methods(loggers):
+            callbacks.register_action(k, callback=getattr(loggers, k))
+
     # Config
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'

@@ -215,13 +221,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
         # model._initialize_biases(cf.to(device))
         if plots:
-            plot_labels(labels, names, save_dir, loggers)
+            plot_labels(labels, names, save_dir)
 
         # Anchors
         if not opt.noautoanchor:
             check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
         model.half().float()  # pre-reduce anchor precision
 
+        callbacks.on_pretrain_routine_end()
+
     # DDP mode
     if cuda and RANK != -1:
         model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

@@ -329,8 +337,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                     f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-                loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
-
+                callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
             # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler

@@ -339,7 +346,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
         if RANK in [-1, 0]:
             # mAP
-            loggers.on_train_epoch_end(epoch)
+            callbacks.on_train_epoch_end(epoch=epoch)
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not noval or final_epoch:  # Calculate mAP

@@ -353,14 +360,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                            save_json=is_coco and final_epoch,
                                            verbose=nc < 50 and final_epoch,
                                            plots=plots and final_epoch,
-                                           loggers=loggers,
+                                           callbacks=callbacks,
                                            compute_loss=compute_loss)
 
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            loggers.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)
+            callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)
 
             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save

@@ -377,7 +384,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
-                loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
+                callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------

@@ -400,7 +407,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-        loggers.on_train_end(last, best, plots)
+        callbacks.on_train_end(last, best, plots, epoch)
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
 
     torch.cuda.empty_cache()
     return results

@@ -448,6 +456,7 @@ def parse_opt(known=False):
 
 
 def main(opt):
+    # Checks
     set_logging(RANK)
     if RANK in [-1, 0]:
         print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
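Since `train()` now accepts a `Callbacks` argument, callers can attach their own actions before training and the hooks fire at the points wired in above. A hedged sketch (the `hyp`/`opt`/`device` plumbing is elided, and the action name 'announce' is hypothetical):

from utils.callbacks import Callbacks

callbacks = Callbacks()
callbacks.register_action('on_model_save', name='announce',
                          callback=lambda last, epoch, final_epoch, best_fitness, fi:
                          print(f'checkpoint saved after epoch {epoch}'))
# train(hyp, opt, device, callbacks)  # each checkpoint save now triggers the print above
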
utils/callbacks.py
ADDED
@@ -0,0 +1,176 @@
#!/usr/bin/env python

class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks
    """

    _callbacks = {
        'on_pretrain_routine_start': [],
        'on_pretrain_routine_end': [],

        'on_train_start': [],
        'on_train_epoch_start': [],
        'on_train_batch_start': [],
        'optimizer_step': [],
        'on_before_zero_grad': [],
        'on_train_batch_end': [],
        'on_train_epoch_end': [],

        'on_val_start': [],
        'on_val_batch_start': [],
        'on_val_image_end': [],
        'on_val_batch_end': [],
        'on_val_end': [],

        'on_fit_epoch_end': [],  # fit = train + val
        'on_model_save': [],
        'on_train_end': [],

        'teardown': [],
    }

    def __init__(self):
        return

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook        The callback hook name to register the action to
            name        The name of the action
            callback    The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook  The name of the hook to check, defaults to all
        """
        if hook:
            return self._callbacks[hook]
        else:
            return self._callbacks

    @staticmethod
    def run_callbacks(register, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks
        """
        for logger in register:
            # print(f"Running callbacks.{logger['callback'].__name__}()")
            logger['callback'](*args, **kwargs)

    def on_pretrain_routine_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each pretraining routine
        """
        self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)

    def on_pretrain_routine_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each pretraining routine
        """
        self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)

    def on_train_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training
        """
        self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)

    def on_train_epoch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training epoch
        """
        self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)

    def on_train_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training batch
        """
        self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)

    def optimizer_step(self, *args, **kwargs):
        """
        Fires all registered callbacks on each optimizer step
        """
        self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)

    def on_before_zero_grad(self, *args, **kwargs):
        """
        Fires all registered callbacks before zero grad
        """
        self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)

    def on_train_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training batch
        """
        self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training epoch
        """
        self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)

    def on_val_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of the validation
        """
        self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)

    def on_val_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each validation batch
        """
        self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)

    def on_val_image_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each val image
        """
        self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)

    def on_val_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each validation batch
        """
        self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)

    def on_val_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of the validation
        """
        self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)

    def on_fit_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each fit (train+val) epoch
        """
        self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)

    def on_model_save(self, *args, **kwargs):
        """
        Fires all registered callbacks after each model save
        """
        self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)

    def on_train_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of training
        """
        self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)

    def teardown(self, *args, **kwargs):
        """
        Fires all registered callbacks before teardown
        """
        self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
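Note a design quirk of the class above: `_callbacks` is a class attribute and `__init__` is empty, so every `Callbacks()` instance shares one registry rather than getting its own. Standalone usage (assuming the file is importable as utils.callbacks):

from utils.callbacks import Callbacks

cb = Callbacks()
cb.register_action('on_val_end', name='done', callback=lambda: print('validation finished'))
print(cb.get_registered_actions('on_val_end'))  # [{'name': 'done', 'callback': <function <lambda> ...>}]
cb.on_val_end()  # prints 'validation finished'
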
utils/general.py
CHANGED
@@ -67,6 +67,11 @@ def try_except(func):
     return handler
 
 
+def methods(instance):
+    # Get class/instance methods
+    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
+
+
 def set_logging(rank=-1, verbose=True):
     logging.basicConfig(
         format="%(message)s",
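`methods()` is what lets train.py auto-register the `Loggers` hooks: every callable, non-dunder attribute is returned, so each public `Loggers` method must be named after a valid hook or `register_action`'s assert will fire. A toy example (the `Demo` class is hypothetical):

from utils.general import methods

class Demo:
    version = 1                 # not callable -> excluded
    def _helper(self): pass     # single underscore -> still included
    def on_train_start(self): pass

print(methods(Demo()))  # ['_helper', 'on_train_start']
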
utils/loggers/__init__.py
CHANGED
@@ -29,10 +29,12 @@ class Loggers():
         self.hyp = hyp
         self.logger = logger  # for printing results to console
         self.include = include
+        self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+                     'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
+                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                     'x/lr0', 'x/lr1', 'x/lr2']  # params
         for k in LOGGERS:
             setattr(self, k, None)  # init empty logger dictionary
-
-    def start(self):
         self.csv = True  # always log to csv
 
         # Message

@@ -57,7 +59,11 @@ class Loggers():
         else:
             self.wandb = None
 
-        return self
+    def on_pretrain_routine_end(self):
+        # Callback runs on pre-train routine end
+        paths = self.save_dir.glob('*labels*.jpg')  # training labels
+        if self.wandb:
+            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
 
     def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
         # Callback runs on train batch end

@@ -78,8 +84,8 @@ class Loggers():
         if self.wandb:
             self.wandb.current_epoch = epoch + 1
 
-    def val_one_image(self, pred, predn, path, names, im):
+    def on_val_image_end(self, pred, predn, path, names, im):
         # Callback runs on val image end
         if self.wandb:
             self.wandb.val_one_image(pred, predn, path, names, im)

@@ -89,25 +95,20 @@ class Loggers():
             files = sorted(self.save_dir.glob('val*.jpg'))
             self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
 
     def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi):
-        # Callback runs on fit epoch end
+        # Callback runs at the end of each fit (train+val) epoch
         vals = list(mloss) + list(results) + lr
-        keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
-                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
-                'x/lr0', 'x/lr1', 'x/lr2']  # params
-        x = {k: v for k, v in zip(keys, vals)}  # dict
-
+        x = {k: v for k, v in zip(self.keys, vals)}  # dict
         if self.csv:
             file = self.save_dir / 'results.csv'
             n = len(x) + 1  # number of cols
-            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # add header
+            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n')  # add header
             with open(file, 'a') as f:
                 f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
 
         if self.tb:
             for k, v in x.items():
                 self.tb.add_scalar(k, v, epoch)
 
         if self.wandb:
             self.wandb.log(x)

@@ -119,20 +120,22 @@ class Loggers():
         if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
             self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
 
-    def on_train_end(self, last, best, plots):
+    def on_train_end(self, last, best, plots, epoch):
         # Callback runs on training end
         if plots:
             plot_results(dir=self.save_dir)  # save results.png
         files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
         files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
+
+        if self.tb:
+            from PIL import Image
+            import numpy as np
+            for f in files:
+                self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC')
+
         if self.wandb:
             wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
             wandb.log_artifact(str(best if best.exists() else last), type='model',
                                name='run_' + self.wandb.wandb_run.id + '_model',
                                aliases=['latest', 'best', 'stripped'])
             self.wandb.finish_run()
-
-    def log_images(self, paths):
-        # Log images
-        if self.wandb:
-            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
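`on_fit_epoch_end` zips the incoming values (train losses, val results, learning rates) against the 13 names in `self.keys`, then appends one fixed-width row to results.csv, emitting the header only when the file does not yet exist. The formatting idiom in isolation, with two stand-in keys:

keys = ['metrics/precision', 'metrics/recall']  # stand-ins for self.keys
vals = [0.71, 0.66]
epoch = 3
n = len(vals) + 1  # +1 for the epoch column
header = ('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n'
row = ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n'
print(header + row)  # aligned 20-char comma-separated columns
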
utils/plots.py
CHANGED
@@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_txt()
     plt.savefig(str(Path(path).name) + '.png', dpi=300)
 
 
-def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
+def plot_labels(labels, names=(), save_dir=Path('')):
     # plot dataset labels
     print('Plotting labels... ')
     c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes

@@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
     matplotlib.use('Agg')
     plt.close()
 
-    # loggers
-    if loggers:
-        loggers.log_images(save_dir.glob('*labels*.jpg'))
-
 
 def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
     # Plot hyperparameter evolution results in evolve.txt
val.py
CHANGED
@@ -25,7 +25,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, ...
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
 from utils.torch_utils import select_device, time_sync
-from utils.loggers import Loggers
+from utils.callbacks import Callbacks
 
 
 def save_one_txt(predn, save_conf, shape, file):

@@ -97,7 +97,7 @@ def run(data,
         dataloader=None,
         save_dir=Path(''),
         plots=True,
-        loggers=Loggers(),
+        callbacks=Callbacks(),
         compute_loss=None,
         ):
     # Initialize/load model and set device

@@ -213,7 +213,7 @@ def run(data,
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            loggers.val_one_image(pred, predn, path, names, img[si])
+            callbacks.on_val_image_end(pred, predn, path, names, img[si])
 
         # Plot images
         if plots and batch_i < 3:

@@ -250,7 +250,7 @@ def run(data,
     # Plots
     if plots:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-        loggers.on_val_end()
+        callbacks.on_val_end()
 
     # Save JSON
     if save_json and len(jdict):

@@ -282,7 +282,7 @@ def run(data,
     model.float()  # for training
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
-        print(f"Results saved to {save_dir}{s}")
+        print(f"Results saved to {colorstr('bold', save_dir)}{s}")
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
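
`val.run()` now carries the same `Callbacks` object: train.py passes its shared instance through, while standalone callers get a fresh default. That makes per-image hooks attachable from outside; a hedged sketch (dataset and weights values are illustrative only):

import val
from utils.callbacks import Callbacks

cb = Callbacks()
cb.register_action('on_val_image_end', name='log_path',
                   callback=lambda pred, predn, path, names, im: print(f'validated {path}'))
# val.run(data='coco128.yaml', weights='yolov5s.pt', callbacks=cb)  # fires once per validated image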