GEM / gem /gem_wrapper.py
WalidBouss's picture
Initial commit :tada:
be1ec96
raw
history blame
5.95 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip.transformer import VisionTransformer
from .gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward
class GEMWrapper(nn.Module):
    """Wrap an open_clip model so its vision tower produces GEM features.

    On construction, the attention of the last ``depth - 1`` residual blocks
    of ``model.visual.transformer`` is replaced by ``SelfSelfAttention``
    (initialised from the original in/out projection weights), each of those
    blocks is wrapped in ``GEMResidualBlock``, and ``model.visual.forward`` is
    rebound to ``modified_vit_forward``. The passed-in model is modified in
    place.

    :param model: open_clip model exposing ``visual`` and ``encode_text``
    :param tokenizer: callable mapping a list of strings to a token tensor
    :param depth: blocks ``resblocks[-1]`` .. ``resblocks[-(depth - 1)]`` are replaced
    :param ss_attn_iter: forwarded to ``SelfSelfAttention``
    :param ss_attn_temp: forwarded to ``SelfSelfAttention``
    """

    def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.depth = depth
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        # Only the first entry of patch_size is used, for both spatial axes
        # (assumes square patches — confirm for non-square models).
        self.patch_size = self.model.visual.patch_size[0]
        self.apply_gem()

    def apply_gem(self):
        """Swap the last ``depth - 1`` attention layers for self-self attention."""
        resblocks = self.model.visual.transformer.resblocks
        for i in range(1, self.depth):
            block = resblocks[-i]
            # Extract geometry from the original ViT attention.
            num_heads = block.attn.num_heads
            dim = int(block.attn.head_dim * num_heads)
            ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=True,
                                        ss_attn_iter=self.ss_attn_iter,
                                        ss_attn_temp=self.ss_attn_temp)
            # Re-use the pretrained projections so the new layer starts from
            # the original model's weights.
            ss_attn.qkv.weight.data = block.attn.in_proj_weight.clone()
            ss_attn.qkv.bias.data = block.attn.in_proj_bias.clone()
            ss_attn.proj.weight.data = block.attn.out_proj.weight.clone()
            ss_attn.proj.bias.data = block.attn.out_proj.bias.clone()
            # Swap the original attention, then wrap the residual block so it
            # can handle SelfSelfAttention's outputs.
            block.attn = ss_attn
            resblocks[-i] = GEMResidualBlock(block)
        # Bind the modified forward as a method of this VisionTransformer instance.
        self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)

    def encode_text(self, text: list):
        """Embed class names using the "a photo of a {cls}." template.

        :param text: list of class names
        :return: L2-normalised text embeddings, shape [1, len(text), dim]
        """
        prompts = [f'a photo of a {cls}.' for cls in text]
        # Put the tokens on the same device as the visual projection.
        tokenized_prompts = self.tokenizer(prompts).to(self.model.visual.proj.device)
        text_embedding = self.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def min_max(self, logits):
        """Rescale each (batch, prompt) map independently to the range [0, 1].

        :param logits: tensor of shape [B, num_prompt, ...] (any rank >= 2)
        :return: tensor with the same shape as ``logits``. A map whose values
            are all equal (zero range) becomes all zeros rather than NaN.
        """
        B, num_prompt = logits.shape[:2]
        flat = logits.reshape(B, num_prompt, -1)
        # [B, num_prompt, 1, ..., 1] so the statistics broadcast for any rank.
        stat_shape = (B, num_prompt) + (1,) * (logits.dim() - 2)
        logits_min = flat.min(dim=-1).values.reshape(stat_shape)
        logits_max = flat.max(dim=-1).values.reshape(stat_shape)
        span = logits_max - logits_min
        # A constant map would otherwise compute 0 / 0 = NaN.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (logits - logits_min) / span

    def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """Compute one heat map per prompt for the input image.

        :param image: torch.Tensor [1, 3, H, W]; H and W are assumed to be
            multiples of the patch size so the patch tokens fold back into a grid
        :param text: list[str] of class names
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: torch.Tensor [1, num_prompt, H, W]

        NOTE: unlike ``batched_forward``, similarities are not scaled by 100;
        the two only agree when ``normalize`` is True.
        """
        H, W = image.shape[-2:]
        grid_h, grid_w = H // self.patch_size, W // self.patch_size
        # The modified visual forward returns (GEM features, original features).
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [1, N, dim]
        text_embeddings = self.encode_text(text)  # [1, num_prompt, dim]
        # Skip token 0 (presumably the class token — confirm) and score every
        # patch token against every prompt.
        img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [1, N-1, num_prompt]
        # Fold the patch sequence back onto its grid (assumes row-major patch order).
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=grid_h, w=grid_w)  # [1, num_prompt, grid_h, grid_w]
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [1, num_prompt, H, W]
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)
        return img_txt_matching

    def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """Compute heat maps for a batch of images, each with its own prompts.

        :param image: torch.Tensor [B, 3, H, W]
        :param text: list[list[str]] - text[i] holds the class names for image i
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: list of B tensors; entry i has shape [len(text[i]), H, W]
        """
        L = len(text)
        cumm_idx = np.cumsum([len(t) for t in text]).tolist()
        B, _, H, W = image.shape
        assert B == L, f'Number of prompts L: {L} should be the same as number of images B: {B}.'
        grid_h, grid_w = H // self.patch_size, W // self.patch_size
        # Image
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [B, N, dim]
        # Text: all prompts are encoded in one call -> [1, P_total, dim]
        flatten_text = [t for sub_text in text for t in sub_text]
        text_embeddings = self.encode_text(flatten_text)
        # Image-Text matching (cheap at patch resolution)
        img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [B, N-1, P_total]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=grid_h, w=grid_w)  # [B, P_total, grid_h, grid_w]
        # Keep only each image's own prompts BEFORE upsampling, so the full
        # resolution [B, P_total, H, W] cross product is never materialised.
        per_image = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
        heat_maps = []
        for i, itm in enumerate(per_image):
            own = itm[i:i + 1]  # [1, len(text[i]), grid_h, grid_w]
            if own.shape[1] == 0:
                # No prompts for this image: interpolate rejects zero channels.
                heat_maps.append(own.new_zeros((0, H, W)))
                continue
            own = F.interpolate(own, size=(H, W), mode='bilinear')  # [1, len(text[i]), H, W]
            if normalize:
                own = self.min_max(own)
            heat_maps.append(own[0])
        return heat_maps