# vc1-base/attn_helper.py
import cv2
import numpy as np
import torch
from PIL import Image


def overlay_attn(original_image, mask):
    """Overlay an attention mask on an RGB image and return the blend as a PIL image."""
    # Colormap for the attention mask (previously COLORMAP_OCEAN).
    colormap_attn = cv2.COLORMAP_VIRIDIS

    # Resize the mask to the original image size and normalize it to [0, 1].
    h, w = original_image.shape[0], original_image.shape[1]
    mask = cv2.resize(mask / mask.max(), (w, h))[..., np.newaxis]

    # Apply the colormap to the mask (note: OpenCV colormaps are returned in BGR order).
    cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)

    # Blend the colormapped mask with the original image.
    alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.4, cmap, 0.6, 0)

    # Convert to a PIL image.
    final_im = Image.fromarray(alpha_blended)
    return final_im


class VITAttentionGradRollout:
    '''
    Expects a timm ViT transformer model.
    Adapted from https://github.com/samiraabnar/attention_flow
    Note: despite the name, no gradients are used; this computes plain attention
    rollout from attention maps captured by forward hooks on each block.
    '''
    def __init__(self, model, head_fusion='min', discard_ratio=0):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = {}
        # Capture the attention map of every transformer block via a forward hook.
        for idx, module in enumerate(model.blocks.children()):
            module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))

    def get_attention(self, name):
        # Returns a forward hook that recomputes the block's softmax attention map
        # from its qkv projection and caches it under self.attentions[name].
        def hook(module, input, output):
            with torch.no_grad():
                input = input[0]
                B, N, C = input.shape
                qkv = (
                    module.qkv(input)
                    .detach()
                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k, _ = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
                attn = (q @ k.transpose(-2, -1)) * module.scale
                attn = attn.softmax(dim=-1)
                self.attentions[name] = attn
        return hook

    def get_attn_mask(self, start_layer=11):
        # Rollout starts from an identity matrix and multiplies in the (fused) attention
        # of each block from `start_layer` onward. With a 12-block ViT the default of 11
        # uses only the final block's attention.
        result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)
        with torch.no_grad():
            for idx in range(start_layer, len(self.attentions)):
                attention = self.attentions[f'attn{idx}']
                # Fuse the attention heads into a single map.
                if self.head_fusion == "mean":
                    attention_heads_fused = attention.mean(dim=1)
                elif self.head_fusion == "max":
                    attention_heads_fused = attention.max(dim=1)[0]
                elif self.head_fusion == "min":
                    attention_heads_fused = attention.min(dim=1)[0]
                else:
                    raise ValueError(f"Attention head fusion type '{self.head_fusion}' not supported")
                # Drop the lowest attentions, but don't drop the class token
                # (assumes batch size 1: only the first row of `flat` is modified).
                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
                indices = indices[indices != 0]
                flat[0, indices] = 0
                # Add the residual connection, renormalize, and accumulate the rollout.
                I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
                a = (attention_heads_fused + 1.0 * I) / 2
                a = a / a.sum(dim=-1).unsqueeze(-1)
                result = torch.matmul(a, result)
        # Total attention between the class token and the image patches.
        mask = result[0, 0, 1:]
        # For a 224x224 image this reshapes 196 patch weights into a 14x14 grid.
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).detach().cpu().numpy()
        mask = mask / np.max(mask)
        return mask