class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use backward hooks to record gradients.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # Hooks can only be registered on tensors that require grad.
            return

        # Gradients are computed in reverse order of the forward pass,
        # so prepend to keep self.gradients aligned with self.activations.
        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            self.gradients = [grad.cpu().detach()] + self.gradients

        output.register_hook(_store_grad)

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()
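
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumes torchvision is installed; resnet18 and its layer4 are placeholder
# choices for the model and target layer, not requirements of the class.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torchvision.models import resnet18

    model = resnet18().eval()
    wrapper = ActivationsAndGradients(
        model, target_layers=[model.layer4], reshape_transform=None)

    x = torch.randn(1, 3, 224, 224)
    output = wrapper(x)                  # forward pass fills wrapper.activations
    score = output[0, output.argmax()]   # logit of the top predicted class
    score.backward()                     # backward pass fills wrapper.gradients

    print(wrapper.activations[0].shape)  # e.g. torch.Size([1, 512, 7, 7])
    print(wrapper.gradients[0].shape)    # matches the activation shape
    wrapper.release()                    # remove the hooks when done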