import cv2
import numpy as np
import torch
from PIL import Image


def overlay_attn(original_image, mask):
    '''Overlay an attention mask on an RGB image and return a PIL image.'''
    colormap_attn = cv2.COLORMAP_VIRIDIS

    # Resize the mask to the original image size; cv2.resize expects (width, height).
    h, w = original_image.shape[0], original_image.shape[1]
    mask = cv2.resize(mask / mask.max(), (w, h))[..., np.newaxis]

    # Apply the colormap to the mask. applyColorMap returns BGR, so convert to RGB
    # before blending with the RGB input and handing the result to PIL.
    cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
    cmap = cv2.cvtColor(cmap, cv2.COLOR_BGR2RGB)

    # Alpha-blend the heatmap with the original image.
    alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.4, cmap, 0.6, 0)

    return Image.fromarray(alpha_blended)


class VITAttentionGradRollout:
    '''
    Attention rollout for a timm ViT transformer model.
    Adapted from https://github.com/samiraabnar/attention_flow
    '''

    def __init__(self, model, head_fusion='min', discard_ratio=0):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = {}
        # Register a forward hook on every transformer block's attention module so
        # that a single forward pass records the per-layer attention maps.
        for idx, module in enumerate(model.blocks.children()):
            module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))

    def get_attention(self, name):
        def hook(module, input, output):
            # timm's Attention module does not return its attention weights, so
            # recompute them from the block input using the module's qkv projection.
            with torch.no_grad():
                input = input[0]
                B, N, C = input.shape
                qkv = (
                    module.qkv(input)
                    .detach()
                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k = qkv[0], qkv[1]
                attn = (q @ k.transpose(-2, -1)) * module.scale
                attn = attn.softmax(dim=-1)
                self.attentions[name] = attn
        return hook

    def get_attn_mask(self, k=11):
        '''
        Roll out the attention maps from block `k` through the final block and
        return the class-token attention over the image patches as a 2-D mask.
        The default of 11 rolls out only the last block of a 12-block ViT.
        '''
        result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)
        with torch.no_grad():
            for idx in range(k, len(self.attentions)):
                attention = self.attentions[f'attn{idx}']
                # Fuse the attention heads.
                if self.head_fusion == "mean":
                    attention_heads_fused = attention.mean(dim=1)
                elif self.head_fusion == "max":
                    attention_heads_fused = attention.max(dim=1)[0]
                elif self.head_fusion == "min":
                    attention_heads_fused = attention.min(dim=1)[0]
                else:
                    raise ValueError(f"Attention head fusion type not supported: {self.head_fusion}")

                # Drop the lowest attentions, but don't drop the class token.
                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
                indices = indices[indices != 0]
                flat[0, indices] = 0

                # Add the identity to account for residual connections,
                # renormalize, and accumulate the rollout.
                I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
                a = (attention_heads_fused + 1.0 * I) / 2
                a = a / a.sum(dim=-1).unsqueeze(-1)
                result = torch.matmul(a, result)

        # Total attention between the class token and the image patches.
        mask = result[0, 0, 1:]
        # For a 224x224 image with 16x16 patches, this reshapes 196 values to 14x14.
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).detach().cpu().numpy()
        mask = mask / np.max(mask)
        return mask
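

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way the two helpers
# above could be combined to visualize attention for a timm ViT. The model
# name, image path, and normalization constants below are illustrative
# assumptions, not values taken from the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import timm  # assumed available; the rollout expects a timm ViT

    model = timm.create_model("vit_base_patch16_224", pretrained=True)
    model.eval()

    # Hooks are registered at construction time, before the forward pass.
    rollout = VITAttentionGradRollout(model, head_fusion="min", discard_ratio=0)

    # Load and preprocess an example image (path is hypothetical).
    img = Image.open("example.jpg").convert("RGB").resize((224, 224))
    img_np = np.array(img)
    x = torch.from_numpy(img_np).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    x = (x - mean) / std

    with torch.no_grad():
        model(x)  # the forward pass fills rollout.attentions via the hooks

    mask = rollout.get_attn_mask()
    overlay = overlay_attn(img_np, mask)
    overlay.save("attention_overlay.png")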