File size: 2,666 Bytes
92894b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""Callback utils."""
import threading
class Callbacks:
    """Handles all registered callbacks for YOLOv5 Hooks.

    Each hook name maps to a list of ``{"name": ..., "callback": ...}``
    entries. Entries are appended by ``register_action`` and invoked, in
    registration order, by ``run``.
    """

    def __init__(self):
        """Create the hook table with an empty action list for every hook."""
        # Define the available callbacks
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire

        Raises:
            AssertionError: If ``hook`` is not a known hook name or
                ``callback`` is not callable.
        """
        # Raised explicitly (rather than via `assert`) so the validation is
        # not stripped when Python runs with -O; the exception type and
        # messages are the same as the assert statements would produce.
        if hook not in self._callbacks:
            raise AssertionError(f"hook '{hook}' not found in callbacks {self._callbacks}")
        if not callable(callback):
            raise AssertionError(f"callback '{callback}' is not callable")
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Return the registered actions, optionally for a single hook.

        Args:
            hook: The name of the hook to check, defaults to all

        Returns:
            The list of action dicts registered on ``hook``, or the whole
            hook -> actions dict when ``hook`` is falsy (``None`` or ``""``).
            The internal objects themselves are returned, not copies.

        Raises:
            KeyError: If a non-empty ``hook`` is not a known hook name.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Loop through the actions registered on a hook and fire each callback.

        Args:
            hook: The name of the hook whose callbacks should be fired
            args: Arguments to receive from YOLOv5
            thread: (boolean) When True, start each callback in its own
                daemon thread instead of calling it in the current thread;
                the started threads are not joined.
            kwargs: Keyword Arguments to receive from YOLOv5

        Raises:
            AssertionError: If ``hook`` is not a known hook name.
        """
        # Explicit raise keeps the check active under `python -O`.
        if hook not in self._callbacks:
            raise AssertionError(f"hook '{hook}' not found in callbacks {self._callbacks}")
        for logger in self._callbacks[hook]:
            if thread:
                threading.Thread(target=logger["callback"], args=args, kwargs=kwargs, daemon=True).start()
            else:
                logger["callback"](*args, **kwargs)
|