Spaces:
Sleeping
Sleeping
class ActivationsAndGradients:
    """Extract activations and record gradients from targeted intermediate layers.

    For every layer in ``target_layers`` two forward hooks are registered:
    one stores the layer's output (the activation), and one attaches a tensor
    hook to that output so its gradient is stored when backward runs.

    Args:
        model: Callable whose forward pass is run by ``__call__``.
        target_layers: Iterable of modules to hook; each must provide
            ``register_forward_hook``.
        reshape_transform: Optional callable applied to both activations and
            gradients before they are stored; ``None`` stores them unchanged.
    """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use backward hook to record gradients. Instead a
            # forward hook registers a tensor hook on the layer's output.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        """Forward hook: store the (optionally reshaped) output, detached, on CPU."""
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        """Forward hook: arrange for the gradient of ``output`` to be stored.

        Outputs without a ``requires_grad`` attribute, or that do not require
        grad, are skipped. In that case ``self.gradients`` can end up shorter
        than ``self.activations``.
        """
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # You can only register hooks on tensors that require grad.
            return

        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            # Gradients are computed in reverse order of the forward pass;
            # inserting at the front keeps them aligned with self.activations.
            self.gradients.insert(0, grad.cpu().detach())

        output.register_hook(_store_grad)

    def __call__(self, x):
        """Clear previously stored tensors, then run the model's forward pass on ``x``."""
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        """Remove every registered hook; safe to call more than once."""
        for handle in self.handles:
            handle.remove()
        # Drop the dead handles so a repeated release is a no-op.
        self.handles.clear()