| | |
| | |
| | |
| | import logging |
| | import torch |
| | from .functions import REG_FUNCTION_MAP |
| |
|
| |
|
| | |
| | |
| | |
class HookMonitor:
    """
    Monitor forward activations and backward gradients of a PyTorch model.

    A forward hook is registered on every submodule returned by
    ``model.named_modules()``. Each hook:

    - detaches the activation tensor and casts it to float;
    - for every metric in ``REG_FUNCTION_MAP`` whose watcher flag is enabled
      (and whose name does not start with ``'grad_'``), computes the metric on
      the activation and accumulates the value together with a ``'/valid/'``
      counter;
    - if the activation requires grad, registers a tensor hook so that metrics
      flagged as ``'grad_<metric>'`` are accumulated from the gradient during
      backpropagation.

    ``get_stats()`` returns, per layer, each accumulated value divided by its
    ``'/valid/'`` counter. The class can be used as a context manager to
    automate hook attachment and removal.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose submodules will be monitored.
    watcher : dict
        Maps metric names (keys of ``REG_FUNCTION_MAP``, optionally prefixed
        with ``'grad_'``) to boolean flags. Disabled metrics are skipped.
        Example: ``{"mean": True, "std": True, "grad_mean": True}``.
    logger : logging.Logger
        Logger used for errors, warnings and debug messages.

    Attributes
    ----------
    stats : dict
        Nested dict of accumulated per-layer statistics (raw sums plus
        ``'/valid/'`` counters); normalized values come from ``get_stats()``.
    handles : list
        Hook handles from ``register_forward_hook``, kept so ``remove()``
        can detach everything safely.

    Notes
    -----
    - No hooks are installed at construction time; call ``attach()`` or enter
      a ``with`` block.
    - The gradient hook is attached to the activation tensor (module output),
      not to model parameters.
    - Forward-hook computation runs under ``torch.no_grad()`` and only reads
      tensors, so training is not affected.
    - Non-tensor module outputs (tuples, dicts, ``None``) are ignored.
    """

    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Store the monitoring configuration; no hooks are registered yet.

        Parameters
        ----------
        model : torch.nn.Module
            Model whose submodules will be monitored once ``attach()`` runs.
        watcher : dict
            Boolean flags selecting which ``REG_FUNCTION_MAP`` metrics (and
            ``'grad_'``-prefixed variants) are computed.
        logger : logging.Logger
            Logger for errors, warnings and debug output.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        # Per-layer accumulators: {layer_name: {metric: sum, metric + '/valid/': count}}
        self.stats: dict = dict()
        # Forward-hook handles, needed to detach cleanly in remove().
        self.handles: list = list()

    def _build_hook(self, name):
        """
        Build the forward hook for the submodule registered under `name`.

        The returned callable matches the ``register_forward_hook`` signature
        ``(module, inputs, output)``. It accumulates activation metrics and,
        when possible, installs a gradient hook on the output tensor.
        """

        @torch.no_grad()
        def hook(module, inputs, output):
            # Only plain tensor outputs are monitored; anything else
            # (tuples, dicts, None) is skipped entirely — this also keeps
            # the grad-hook registration below safe.
            if not torch.is_tensor(output):
                return

            act = output.detach().float()
            layer_stats = self.stats.setdefault(name, {})

            for metric, compute_function in REG_FUNCTION_MAP.items():
                # 'grad_*' entries are gradient-only metrics, handled below.
                if metric.startswith('grad_') or not self.watcher.get(metric, False):
                    continue
                # NOTE(review): REG_FUNCTION_MAP callables are invoked as
                # fn(tensor, ...) — the Ellipsis argument is part of their
                # contract elsewhere in the project; preserved as-is.
                value = compute_function(act, ...)
                if value is not None:
                    layer_stats[metric] = layer_stats.get(metric, 0.0) + value
                    layer_stats[metric + '/valid/'] = layer_stats.get(metric + '/valid/', 0.0) + 1

            def grad_hook(grad):
                # Runs during backprop with the gradient of the activation.
                g = grad.detach().float()
                for metric, compute_function in REG_FUNCTION_MAP.items():
                    grad_metric = 'grad_' + metric
                    if metric.startswith('grad_') or not self.watcher.get(grad_metric, False):
                        continue
                    value = compute_function(g, ...)
                    if value is not None:
                        layer_stats[grad_metric] = layer_stats.get(grad_metric, 0.0) + value
                        layer_stats[grad_metric + '/valid/'] = layer_stats.get(grad_metric + '/valid/', 0.0) + 1

            # Hook the raw (non-detached) output so autograd invokes it.
            if output.requires_grad:
                output.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Return normalized statistics per layer.

        Each accumulated metric is divided by its ``'/valid/'`` counter; the
        counters themselves are omitted from the result. A metric without a
        counter is logged as an error and returned un-normalized.

        :return: ``{layer_name: {metric: normalized_value}}``
        """
        stats = dict()
        for layer_name, layer_stats in self.stats.items():
            sub_stats = dict()
            for key, item in layer_stats.items():
                if '/valid/' not in key:
                    if key + '/valid/' in layer_stats:
                        sub_stats[key] = item / layer_stats[key + '/valid/']
                    else:
                        self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                        sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Register a forward hook on every submodule of the model.

        Idempotent: calling attach() while hooks are already installed would
        double-register every hook and double-count all statistics, so the
        call is skipped with a warning instead.

        :return: The object (allows ``with HookMonitor(...).attach()`` style).
        """
        if self.handles:
            self.logger.warning("[Hooks] attach() called while hooks are already registered; skipping.")
            return self
        for name, module in self.model.named_modules():
            h = module.register_forward_hook(self._build_hook(name))
            self.handles.append(h)
        return self

    def clear(self):
        """
        Clear the accumulated statistics (hooks stay attached).

        :return: Nothing
        """
        self.stats.clear()

    def remove(self):
        """
        Remove all hooks from the model and forget their handles.

        :return: Nothing.
        """
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def __enter__(self):
        # Context-manager entry: install hooks and hand back the monitor.
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Context-manager exit: always detach, even on exception.
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()
| |
|
| | |
| | |
| | |
| |
|