import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip.transformer import VisionTransformer

from .gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward


class GEMWrapper(nn.Module):
    """Wrap an open_clip model so its visual encoder produces GEM patch features.

    On construction, the last ``depth - 1`` residual blocks of
    ``model.visual.transformer`` get their attention replaced by a
    ``SelfSelfAttention`` initialised from the original projection weights, and
    each is wrapped in a ``GEMResidualBlock``. ``model.visual.forward`` is
    rebound to ``modified_vit_forward``, which (as used below) returns a pair
    ``(gem_features, original_features)``. The wrapped model is modified in place.
    """

    def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
        """
        :param model: open_clip model exposing ``visual`` and ``encode_text``.
        :param tokenizer: callable mapping a list of strings to a token tensor.
        :param depth: the last ``depth - 1`` visual residual blocks are converted.
        :param ss_attn_iter: forwarded unchanged to ``SelfSelfAttention``.
        :param ss_attn_temp: forwarded unchanged to ``SelfSelfAttention``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.depth = depth
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        self.patch_size = self.model.visual.patch_size[0]
        self.apply_gem()

    def apply_gem(self):
        """Swap attention in the last ``depth - 1`` visual blocks for self-self attention (in place)."""
        resblocks = self.model.visual.transformer.resblocks
        for i in range(1, self.depth):
            block = resblocks[-i]
            orig_attn = block.attn

            # Extract the attention geometry from the original ViT layer.
            num_heads = orig_attn.num_heads
            dim = int(orig_attn.head_dim * num_heads)
            # The bias tensors are copied below, so the new layer must have them.
            qkv_bias = True

            ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                        ss_attn_iter=self.ss_attn_iter,
                                        ss_attn_temp=self.ss_attn_temp)

            # Re-use the pretrained projection weights so no training is needed.
            ss_attn.qkv.weight.data = orig_attn.in_proj_weight.clone()
            ss_attn.qkv.bias.data = orig_attn.in_proj_bias.clone()
            ss_attn.proj.weight.data = orig_attn.out_proj.weight.clone()
            ss_attn.proj.bias.data = orig_attn.out_proj.bias.clone()

            # Swap the original attention, then wrap the residual block so it
            # can handle SelfSelfAttention's outputs.
            block.attn = ss_attn
            resblocks[-i] = GEMResidualBlock(block)

        # Rebind the ViT forward so it returns both GEM and original features.
        self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)

    def encode_text(self, text: list, template: str = 'a photo of a {}.'):
        """Embed each class name inserted into a prompt template.

        :param text: list of class names.
        :param template: format string with a single ``{}`` slot; the default is
            the prompt this wrapper has always used.
        :return: L2-normalised text embeddings, [1, len(text), dim].
        """
        prompts = [template.format(cls) for cls in text]
        # Put the tokens on the same device as the visual projection weights.
        device = self.model.visual.proj.device
        tokenized_prompts = self.tokenizer(prompts).to(device)
        text_embedding = F.normalize(self.model.encode_text(tokenized_prompts), dim=-1)
        return text_embedding.unsqueeze(0)

    def min_max(self, logits):
        """Min-max normalise every map in ``logits`` independently to [0, 1].

        :param logits: tensor [B, num_prompt, ...]; the min/max is taken over
            all trailing dimensions of each (batch, prompt) pair.
        :return: tensor of the same shape. A map whose values are all equal
            (zero range) becomes all zeros instead of the NaNs 0/0 would give.
        """
        B, num_prompt = logits.shape[:2]
        flat = logits.reshape(B, num_prompt, -1)
        # Restore singleton trailing dims so the stats broadcast for any input rank.
        stat_shape = (B, num_prompt) + (1,) * (logits.dim() - 2)
        logits_min = flat.min(dim=-1).values.reshape(stat_shape)
        logits_max = flat.max(dim=-1).values.reshape(stat_shape)
        value_range = logits_max - logits_min
        # Avoid 0/0 on constant maps; the numerator is already 0 there.
        value_range = value_range.masked_fill(value_range == 0, 1)
        return (logits - logits_min) / value_range

    def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """
        :param image: torch.Tensor [1, 3, H, W]
        :param text: list[] of class names
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: torch.Tensor [1, num_prompt, H, W] image-text similarity heat maps
        """
        # Image
        height, width = image.shape[-2:]
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [1, N, dim]

        # Text
        text_embeddings = self.encode_text(text)  # [1, num_prompt, dim]

        # Image-Text matching: token 0 is skipped (presumably the class token).
        # NOTE(review): unlike batched_forward, no 100x scale is applied here;
        # this only matters when normalize=False.
        img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [1, N-1, num_prompt]
        # Patch tokens are laid out with rows outermost.
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=height // self.patch_size,
                                     w=width // self.patch_size)  # [1, num_prompt, h, w]

        # Upsample the patch grid back to the input resolution.
        img_txt_matching = F.interpolate(img_txt_matching, size=(height, width),
                                         mode='bilinear')  # [1, num_prompt, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)
        return img_txt_matching

    def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """
        :param image: torch.Tensor [B, 3, H, W]
        :param text: list[list[]] - one list of class names per image
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: list of B tensors; entry i is [len(text[i]), H, W]
        """
        L = len(text)
        # Boundaries of each image's prompts inside the flattened prompt list.
        cumm_idx = np.cumsum([len(t) for t in text]).tolist()
        B, _, height, width = image.shape
        assert B == L, f'Number of prompts L: {L} should be the same as number of images B: {B}.'

        # Image
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [B, N, dim]

        # Text: all prompts of all images are encoded in a single pass.
        flatten_text = [t for sub_text in text for t in sub_text]
        text_embeddings = self.encode_text(flatten_text)  # [1, total_prompts, dim]

        # Image-Text matching: every image is scored against every prompt; the
        # per-image slice is taken at the end.
        img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [B, N-1, total_prompts]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=height // self.patch_size,
                                     w=width // self.patch_size)  # [B, total_prompts, h, w]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(height, width),
                                         mode='bilinear')  # [B, total_prompts, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)  # [B, total_prompts, H, W]

        # Unflatten: keep, for image i, only the maps of its own prompts.
        per_group = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
        return [group[i] for i, group in enumerate(per_group)]