from collections import OrderedDict
from typing import Tuple, Union
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from timm.models.layers import DropPath, trunc_normal_

from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .det_swin import SwinTransformer

from ..text_encoder import build_text_encoder
from ..text_encoder import build_tokenizer


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int,
                 attn_mask: torch.Tensor = None, drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(self, context_length: int, vocab_size: int,
                 width: int, layers: int, heads: int, drop_path: float = 0.0):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        attn_mask = self.build_attention_mask()
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
        self.resblocks = nn.Sequential(
            *[
                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
                for i in range(layers)
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        # nn.init.normal_(self.token_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        self.apply(self._init_weights)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, text: torch.Tensor):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.resblocks(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x

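
# Illustrative note, not executed at runtime: `Transformer.build_attention_mask()` above
# returns an additive causal mask, so each text token attends only to itself and earlier
# positions. For context_length = 4 the mask is:
#
#     [[0., -inf, -inf, -inf],
#      [0.,   0., -inf, -inf],
#      [0.,   0.,   0., -inf],
#      [0.,   0.,   0.,   0.]]
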

class CLIP(Backbone):
    def __init__(self, config: dict):
        super().__init__()

        spec_text = config['MODEL']['SPEC']['TEXT']
        assert spec_text['TOKENIZER'] == 'clip', 'Only support clip tokenizer'
        self.tokenizer_style = spec_text['TOKENIZER']
        self.tokenizer = build_tokenizer(spec_text)

        self.text_encoder = build_text_encoder(spec_text, self.tokenizer, True)

        embed_dim = config['MODEL']['SPEC']['EMBED_DIM']
        self.text_projection = nn.Parameter(
            torch.empty(spec_text['WIDTH'], embed_dim)
        )

        spec_vision = config['MODEL']['SPEC']['VISION']
        self.image_encoder = SwinTransformer(
            patch_size=spec_vision['PATCH_SIZE'],
            in_chans=spec_vision['IN_CHANS'],
            embed_dim=spec_vision['EMBED_DIM'],
            depths=spec_vision['DEPTHS'],
            num_heads=spec_vision['NUM_HEADS'],
            window_size=spec_vision['WINDOW_SIZE'],
            mlp_ratio=spec_vision['MLP_RATIO'],
            qkv_bias=spec_vision['QKV_BIAS'],
            qk_scale=spec_vision.get('QK_SCALE', None),
            drop_rate=spec_vision['DROP_RATE'],
            attn_drop_rate=spec_vision['ATTN_DROP_RATE'],
            drop_path_rate=spec_vision['DROP_PATH_RATE'],
            ape=spec_vision['APE'],
            patch_norm=spec_vision['PATCH_NORM'],
            out_indices=(0, 1, 2, 3),
            frozen_stages=-1,
            use_checkpoint=False,
        )
        # channel width of the last Swin stage, used for the image projection
        width = spec_vision['EMBED_DIM'] * 2 ** (len(spec_vision['DEPTHS']) - 1)
        self.image_projection = nn.Parameter(
            torch.empty(width, embed_dim)
        )

        # self.logit_scale = nn.Parameter(torch.FloatTensor([np.log(1 / 0.07)]))
        self.logit_scale = nn.Parameter(torch.ones([]))

        trunc_normal_(self.text_projection, std=.02)
        trunc_normal_(self.image_projection, std=.02)

    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')
                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        no_weight_decay = {'logit_scale'}
        for k in self.text_encoder.no_weight_decay():
            no_weight_decay.add('text.' + k)

        for k in self.image_encoder.no_weight_decay():
            no_weight_decay.add('visual.' + k)

        return no_weight_decay

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    @property
    def dtype(self):
        # the patch-embedding projection is the first parameterized layer of the Swin backbone
        return self.image_encoder.patch_embed.proj.weight.dtype

    def encode_image(self, image, norm=True):
        x = self.image_encoder(image)
        return x

    def encode_text(self, text, norm=True):
        assert isinstance(text, str), "only support single query"
        tokens = self.tokenizer(
            text, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
        )
        tokens = {
            key: (val.cuda() if next(self.parameters()).is_cuda else val)
            for key, val in tokens.items()
        }
        x = self.text_encoder(**tokens)
        x = x['last_hidden_state']
        # pool at the end-of-text token (highest token id), then project to the joint space
        x = x[torch.arange(x.size(0)), tokens['input_ids'].argmax(dim=-1)]
        x = x @ self.text_projection
        if norm:
            x = x / x.norm(dim=-1, keepdim=True)
        return x

    def forward(self, image):
        features_image = self.image_encoder(image)
        return features_image

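
# Usage sketch (illustrative only; assumes `model` is a CLIP instance built from a config
# with MODEL.SPEC filled in, and `images` is a batch tensor on the same device):
#
#     feats = model(images)                        # multi-scale Swin feature maps
#     txt = model.encode_text("a photo of a cat")  # (1, EMBED_DIM), L2-normalized
#     scale = model.logit_scale.exp()              # learnable temperature for similarity logits
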

@BACKBONE_REGISTRY.register()
def build_clip_swin_backbone(cfg, input_shape):
    """
    Create a CLIP Swin instance from config.

    Returns:
        SwinTransformer: a :class:`SwinTransformer` instance.
    """
    spec_vision = cfg.MODEL.CLIP.VISION
    return SwinTransformer(
        patch_size=spec_vision['PATCH_SIZE'],
        in_chans=spec_vision['IN_CHANS'],
        embed_dim=spec_vision['EMBED_DIM'],
        depths=spec_vision['DEPTHS'],
        num_heads=spec_vision['NUM_HEADS'],
        window_size=spec_vision['WINDOW_SIZE'],
        mlp_ratio=spec_vision['MLP_RATIO'],
        qkv_bias=spec_vision['QKV_BIAS'],
        qk_scale=spec_vision.get('QK_SCALE', None),
        drop_rate=spec_vision['DROP_RATE'],
        attn_drop_rate=spec_vision['ATTN_DROP_RATE'],
        drop_path_rate=spec_vision['DROP_PATH_RATE'],
        ape=spec_vision['APE'],
        patch_norm=spec_vision['PATCH_NORM'],
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    )


@BACKBONE_REGISTRY.register()
def build_clip_swin(cfg, input_shape):
    """
    Create a full CLIP model (Swin vision encoder + text encoder) from config.

    Returns:
        CLIP: a :class:`CLIP` instance.
    """
    return CLIP(cfg)
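
# Registry dispatch sketch (illustrative; assumes a detectron2-style config in which
# cfg.MODEL.BACKBONE.NAME names one of the builders registered above):
#
#     cfg.MODEL.BACKBONE.NAME = "build_clip_swin"   # or "build_clip_swin_backbone"
#     backbone = BACKBONE_REGISTRY.get(cfg.MODEL.BACKBONE.NAME)(cfg, input_shape)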