import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "forward_hook",
    "Clone",
    "Add",
    "Cat",
    "ReLU",
    "GELU",
    "Dropout",
    "BatchNorm2d",
    "Linear",
    "MaxPool2d",
    "AdaptiveAvgPool2d",
    "AvgPool2d",
    "Conv2d",
    "Sequential",
    "safe_divide",
    "einsum",
    "Softmax",
    "IndexSelect",
    "LayerNorm",
    "AddEye",
]


def safe_divide(a, b):
    """Return ``a / (b + 1e-9)`` elementwise, forced to 0 wherever ``b == 0``.

    NOTE: ``b.clamp(min=1e-9) + b.clamp(max=1e-9)`` evaluates to ``b + 1e-9``
    for any ``b`` (one clamp always returns ``b``, the other ``1e-9``), so the
    stabiliser is a constant positive shift rather than a sign-aware epsilon.
    The expression is kept as-is to preserve existing numerics.
    """
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    # b == -1e-9 would make den exactly 0; bump those entries to 1e-9.
    den = den + den.eq(0).type(den.type()) * 1e-9
    # Zero the quotient wherever the original denominator was exactly 0.
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    """Forward hook caching a differentiable copy of the module's first input.

    ``self.X`` receives a detached, ``requires_grad=True`` copy of
    ``input[0]`` (a list of such copies when ``input[0]`` is a list/tuple of
    tensors), so ``relprop`` can re-run the forward pass on it and take
    gradients. ``self.Y`` keeps the raw module output.
    """
    first = input[0]
    if isinstance(first, (list, tuple)):
        self.X = [t.detach().requires_grad_() for t in first]
    else:
        self.X = first.detach().requires_grad_()
    self.Y = output


def backward_hook(self, grad_input, grad_output):
    """Store the gradients seen by a backward hook on the module.

    Not registered by any class in this module.
    """
    self.grad_input = grad_input
    self.grad_output = grad_output


def _input_times_grad(X, C):
    """Multiply cached input(s) ``X`` by the matching gradient(s) in ``C``.

    ``C`` is the tuple returned by ``torch.autograd.grad``. For a tensor ``X``
    only ``C[0]`` is used; for a list ``X`` every element is paired with its
    gradient, however many inputs there are.
    """
    if torch.is_tensor(X):
        return X * C[0]
    return [x * c for x, c in zip(X, C)]


def _broadcast_extreme(X, reduce):
    """Reduce ``X`` over dims 1, 2 and 3 with ``reduce`` and broadcast back.

    ``reduce`` is ``torch.min`` or ``torch.max``. The result has ``X``'s shape
    and holds one value per sample (index along dim 0).
    """
    out = X
    for dim in (1, 2, 3):
        out = reduce(out, dim=dim, keepdim=True)[0]
    # X * 0 keeps the shape/dtype/device of X and broadcasts the extreme.
    return X * 0 + out


class RelProp(nn.Module):
    """Base class for layers that propagate relevance backwards via ``relprop``.

    Registers ``forward_hook`` so every forward call caches its input in
    ``self.X``. The default ``relprop`` passes relevance through unchanged.
    """

    def __init__(self):
        super(RelProp, self).__init__()
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Gradient of ``Z`` w.r.t. ``X`` weighted by ``S`` (graph retained)."""
        return torch.autograd.grad(Z, X, S, retain_graph=True)

    def relprop(self, R, alpha):
        """Identity rule: return the incoming relevance ``R`` unchanged."""
        return R


class RelPropSimple(RelProp):
    """Relevance rule ``R_in = X * d/dX <forward(X), R / forward(X)>``."""

    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)
        return _input_times_grad(self.X, C)


class AddEye(RelPropSimple):
    """Add an identity matrix over the last two dims of the input.

    Input of shape (B, C, seq_len, seq_len); the identity is sized by
    ``input.shape[2]`` and expanded to the input's shape.
    """

    def forward(self, input):
        # Build the identity directly in the input's dtype/device so that
        # reduced-precision inputs are not promoted to float32.
        eye = torch.eye(input.shape[2], dtype=input.dtype, device=input.device)
        return input + eye.expand_as(input)


class ReLU(nn.ReLU, RelProp):
    pass


class GELU(nn.GELU, RelProp):
    pass


class Softmax(nn.Softmax, RelProp):
    pass


class LayerNorm(nn.LayerNorm, RelProp):
    pass


class Dropout(nn.Dropout, RelProp):
    pass


class MaxPool2d(nn.MaxPool2d, RelPropSimple):
    pass


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
    pass


class AvgPool2d(nn.AvgPool2d, RelPropSimple):
    pass


class Add(RelPropSimple):
    """Elementwise sum of a list of inputs; called as ``Add()([a, b])``."""

    def forward(self, inputs):
        return torch.add(*inputs)


class einsum(RelPropSimple):
    """``torch.einsum`` with a fixed equation; called as ``einsum(eq)([x, y])``."""

    def __init__(self, equation):
        super().__init__()
        self.equation = equation

    def forward(self, *operands):
        return torch.einsum(self.equation, *operands)


class IndexSelect(RelProp):
    """``torch.index_select`` that remembers ``dim``/``indices`` for ``relprop``."""

    def forward(self, inputs, dim, indices):
        self.dim = dim
        self.indices = indices
        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)
        return _input_times_grad(self.X, C)


class Clone(RelProp):
    """Return ``num`` references to the same input; relevance is summed back."""

    def forward(self, input, num):
        self.num = num
        return [input] * num

    def relprop(self, R, alpha):
        Z = [self.X] * self.num
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        C = self.gradprop(Z, self.X, S)[0]
        return self.X * C


class Cat(RelProp):
    """``torch.cat`` over a list of inputs; relevance is split per input."""

    def forward(self, inputs, dim):
        self.dim = dim
        return torch.cat(inputs, dim)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)
        return [x * c for x, c in zip(self.X, C)]


class Sequential(nn.Sequential):
    """``nn.Sequential`` whose ``relprop`` runs each child's rule in reverse."""

    def relprop(self, R, alpha):
        for m in reversed(self._modules.values()):
            R = m.relprop(R, alpha)
        return R


class BatchNorm2d(nn.BatchNorm2d, RelProp):
    def relprop(self, R, alpha):
        """Redistribute ``R`` through the per-channel scale of the norm.

        Uses the running-statistics scale ``gamma / sqrt(running_var + eps)``
        (presumably meant for eval mode — the forward may use batch
        statistics in training mode; confirm with callers). The bias and
        mean shift are not included, and ``alpha`` is unused.
        """
        # Per-channel factor broadcast to (1, C, 1, 1). running_var already
        # holds a variance, so it must not be squared again.
        weight = self.weight.view(1, -1, 1, 1) / (
            self.running_var.view(1, -1, 1, 1) + self.eps
        ).pow(0.5)
        Z = self.X * weight + 1e-9
        S = R / Z
        return self.X * (S * weight)


class Linear(nn.Linear, RelProp):
    def relprop(self, R, alpha):
        """Alpha-beta rule: ``alpha * activators - beta * inhibitors``.

        ``beta = alpha - 1``. Positive/negative parts of weights and inputs
        are combined so activators use (w+, x+) and (w-, x-), inhibitors use
        (w-, x+) and (w+, x-). The bias is ignored.
        """
        beta = alpha - 1
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = F.linear(x1, w1)
            Z2 = F.linear(x2, w2)
            S1 = safe_divide(R, Z1)
            S2 = safe_divide(R, Z2)
            C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
            C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)
        return alpha * activator_relevances - beta * inhibitor_relevances


class Conv2d(nn.Conv2d, RelProp):
    def _conv_nobias(self, x, weight):
        """Convolve ``x`` with ``weight`` using this layer's geometry, no bias."""
        return F.conv2d(
            x,
            weight,
            bias=None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def gradprop2(self, DY, weight):
        """Apply the transpose of this convolution (with ``weight``) to ``DY``.

        ``DY`` has the layer's output shape. ``output_padding`` is computed
        independently for height and width so the result matches the spatial
        size of ``self.X`` even for non-square inputs.
        """
        output_padding = tuple(
            x_size - ((y_size - 1) * s - 2 * p + d * (k - 1) + 1)
            for x_size, y_size, s, p, d, k in zip(
                self.X.shape[2:],
                DY.shape[2:],
                self.stride,
                self.padding,
                self.dilation,
                self.kernel_size,
            )
        )
        return F.conv_transpose2d(
            DY,
            weight,
            stride=self.stride,
            padding=self.padding,
            output_padding=output_padding,
            groups=self.groups,
            dilation=self.dilation,
        )

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` back to the cached input ``self.X``.

        Inputs with exactly 3 channels use a bounded-input (z^B-style) rule;
        all others use the alpha-beta rule. NOTE(review): the 3-channel test
        presumably identifies the first (RGB image) layer — confirm.
        """
        if self.X.shape[1] == 3:
            return self._relprop_bounded(R)
        return self._relprop_alpha_beta(R, alpha)

    def _relprop_bounded(self, R):
        """Bounded-input rule using per-sample input min (L) and max (H)."""
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        X = self.X
        L = _broadcast_extreme(X, torch.min)
        H = _broadcast_extreme(X, torch.max)

        Za = (
            self._conv_nobias(X, self.weight)
            - self._conv_nobias(L, pw)
            - self._conv_nobias(H, nw)
            + 1e-9
        )
        S = R / Za
        return (
            X * self.gradprop2(S, self.weight)
            - L * self.gradprop2(S, pw)
            - H * self.gradprop2(S, nw)
        )

    def _relprop_alpha_beta(self, R, alpha):
        """Alpha-beta rule (``beta = alpha - 1``), bias ignored."""
        beta = alpha - 1
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = self._conv_nobias(x1, w1)
            Z2 = self._conv_nobias(x2, w2)
            S1 = safe_divide(R, Z1)
            S2 = safe_divide(R, Z2)
            C1 = x1 * self.gradprop(Z1, x1, S1)[0]
            C2 = x2 * self.gradprop(Z2, x2, S2)[0]
            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)
        return alpha * activator_relevances - beta * inhibitor_relevances