# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Callback utils
"""

import threading


class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks.

    Each known hook name maps to a list of ``{"name": ..., "callback": ...}``
    dicts; :meth:`run` fires them in registration order.
    """

    def __init__(self):
        # Define the available callbacks
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def _check_hook(self, hook):
        """
        Raise AssertionError if ``hook`` is not one of the known hook names.

        An explicit ``raise`` is used instead of ``assert`` so the check is not
        stripped when Python runs with ``-O``; the exception type and message
        are the same as the previous ``assert`` for backward compatibility.
        """
        if hook not in self._callbacks:
            raise AssertionError(f"hook '{hook}' not found in callbacks {self._callbacks}")

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire

        Raises:
            AssertionError: if ``hook`` is unknown or ``callback`` is not callable
        """
        self._check_hook(hook)
        # Explicit check (not ``assert``) so a non-callable is rejected here even
        # under ``python -O``, instead of failing later inside ``run``.
        if not callable(callback):
            raise AssertionError(f"callback '{callback}' is not callable")
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Return the registered actions.

        Args:
            hook: The name of the hook to check; if falsy (None or ""), the whole
                hook -> actions mapping is returned

        Returns:
            The internal list of action dicts for ``hook``, or the internal
            mapping of all hooks. These are not copies, so mutating them changes
            the registry.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Fire every action registered to ``hook``, in registration order.

        Args:
            hook: The name of the hook to fire (must be a known hook)
            args: Positional arguments forwarded to each callback
            thread: (boolean) If True, start each callback in its own daemon thread
                and return without waiting for them; otherwise call each callback
                on the current thread
            kwargs: Keyword arguments forwarded to each callback

        Raises:
            AssertionError: if ``hook`` is unknown
        """
        self._check_hook(hook)
        # Iterate over a snapshot: an action registered by a callback during this
        # dispatch is not fired in the same pass (and cannot cause an endless loop).
        for logger in list(self._callbacks[hook]):
            if thread:
                threading.Thread(
                    target=logger["callback"],
                    args=args,
                    kwargs=kwargs,
                    daemon=True,
                ).start()
            else:
                logger["callback"](*args, **kwargs)