import abc
import math

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from timm.models.vision_transformer import (
    VisionTransformer,
    build_model_with_cfg,
    checkpoint_filter_fn,
    checkpoint_seq,
    resolve_pretrained_cfg,
)
from torch import Tensor, nn


class BlankLayer(nn.Module):
    """Empty container module; used so the image trunk can live under `visual.trunk`."""
    pass


class CustomViT(VisionTransformer):
    def __init__(
        self,
        *args,
        image_pooling="gmp",
        **kwargs,
    ):
        super(CustomViT, self).__init__(*args, **kwargs)
        self.image_pooling = image_pooling

    def forward_head(self, x, pre_logits: bool = False):
        if self.image_pooling:
            if self.image_pooling == "gap":
                # global average pooling over patch tokens
                x = x[:, self.num_prefix_tokens:].mean(dim=1)
            elif self.image_pooling == "gmp":
                # global max pooling over patch tokens
                x = x[:, self.num_prefix_tokens:].max(dim=-2)[0]
            elif self.image_pooling == "all":
                # keep all patch tokens
                x = x[:, self.num_prefix_tokens:]
            else:  # cls token by default
                x = x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, get_pos_tokens=False):
        x = self.forward_features(x, get_pos_tokens=get_pos_tokens)
        if get_pos_tokens:
            return self.fc_norm(x[:, self.num_prefix_tokens:])
        x = self.forward_head(x)
        return x

    def forward_features(self, x, get_pos_tokens=False):
        _, nc, h, w = x.shape
        x = self.patch_embed(x)
        x = self._pos_embed(x, w, h)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def _pos_embed(self, x, w, h):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self._interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def _interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size[0]
        h0 = h // self.patch_embed.patch_size[1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get("features_only", None):
        raise RuntimeError("features_only not implemented for Vision Transformer models.")

    pretrained_cfg = resolve_pretrained_cfg(
        variant, pretrained_cfg=kwargs.pop("pretrained_cfg", None)
    )
    model = build_model_with_cfg(
        CustomViT,
        variant,
        pretrained,
        pretrained_cfg=pretrained_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load="npz" in pretrained_cfg["url"],
        **kwargs,
    )
    return model


def vit_base_patch16_224(pretrained=False, variant="vit_base_patch16_224_dino", **kwargs):
    """ViT-Base (ViT-B/16) w/ DINO pretrained weights (no head)
    - https://arxiv.org/abs/2104.14294
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(variant, pretrained=pretrained, **model_kwargs)
    return model


class CLIPpyModel(abc.ABC, torch.nn.Module):
    """Implements code for running inference with a pre-trained CLIPpy model.

    NOTE: the weights used are for a model trained with a lower batch size,
    leading to results below those reported in the paper.
    """

    def __init__(
        self,
        image_pooling: str = "cls",
        text_pooling: str = "gap",
    ):
        super().__init__()

        self.visual = BlankLayer()
        self.visual.trunk = vit_base_patch16_224(True, image_pooling=image_pooling)
        self.text = SentenceTransformer("sentence-transformers/sentence-t5-base")
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

        self.set_text_pooling(text_pooling)

        self._divisor_eps = 1e-4
        self._image_pooling = image_pooling
        self._text_pooling = text_pooling

    def forward(
        self,
        images: Tensor,
        input_ids: Tensor,
        input_id_masks: Tensor,
        get_pos_tokens: bool = False,
        **kwargs,
    ):
        image_encodings = self.encode_image(images, get_pos_tokens=get_pos_tokens)

        if get_pos_tokens:
            return {
                "image_encodings": image_encodings,
            }

        text_encodings = self.encode_text(input_ids, input_id_masks)

        return {
            "image_encodings": image_encodings,
            "text_encodings": text_encodings,
        }

    def encode_text(self, input_ids: Tensor, input_id_masks: Tensor = None, **kwargs):
        output = self.text({"input_ids": input_ids, "attention_mask": input_id_masks})[
            "sentence_embedding"
        ]
        return self.text_head(output)

    def text_head(self, hidden_states: Tensor, input_id_masks: Tensor = None, **kwargs):
        return F.normalize(hidden_states, dim=-1, eps=self._divisor_eps).float()

    def encode_image(self, images: Tensor, get_pos_tokens: bool = False, **kwargs):
        output = self.visual.trunk(images, get_pos_tokens)
        return self.image_head(output, get_pos_tokens=get_pos_tokens)

    def image_head(self, hidden_states: Tensor, get_pos_tokens: bool = False, **kwargs):
        return F.normalize(hidden_states, dim=-1, eps=self._divisor_eps).float()

    def set_text_pooling(self, pooling):
        """Converts pooling in the Hugging Face text model to max or average pooling."""
        if pooling == "gmp":
            self.text[1].pooling_mode_mean_tokens = False
            self.text[1].pooling_mode_max_tokens = True
        elif pooling == "gap":
            pass
        else:
            raise NotImplementedError(f"{pooling} not implemented")
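

# Usage sketch (illustrative addition, not part of the original module): shows one way
# the model above could be exercised end to end. It assumes the timm checkpoint
# "vit_base_patch16_224_dino" and the "sentence-transformers/sentence-t5-base" weights
# can be downloaded, and uses a random image tensor plus placeholder captions as inputs.
if __name__ == "__main__":
    model = CLIPpyModel()
    model.eval()

    # Placeholder batch: one random 224x224 RGB image and two candidate captions.
    images = torch.randn(1, 3, 224, 224)
    captions = ["a photo of a dog", "a photo of a cat"]

    # SentenceTransformer.tokenize returns a dict with "input_ids" and "attention_mask".
    tokens = model.text.tokenize(captions)

    with torch.no_grad():
        outputs = model(images, tokens["input_ids"], tokens["attention_mask"])

    # Both encodings are L2-normalized, so the dot product is a cosine similarity;
    # logit_scale is the CLIP-style learned temperature.
    logits = model.logit_scale.exp() * outputs["image_encodings"] @ outputs["text_encodings"].T
    print(logits.shape)  # expected: torch.Size([1, 2])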