# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Union, List
from collections import OrderedDict

import torch
from torch import nn

from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["tokenize"]


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask)
              for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class CLIPTEXT(nn.Module):
    def __init__(self,
                 embed_dim=512,
                 # text
                 context_length=77,
                 vocab_size=49408,
                 transformer_width=512,
                 transformer_heads=8,
                 transformer_layers=12
                 ):
        super().__init__()
        self._tokenizer = _Tokenizer()
        self.context_length = context_length

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create the causal attention mask over the text tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def device(self):
        return self.text_projection.device

    @property
    def dtype(self):
        return self.text_projection.dtype

    def tokenize(self,
                 texts: Union[str, List[str]],
                 context_length: int = 77) -> torch.LongTensor:
        """
        Convert a string or list of strings into zero-padded token-id tensors
        of shape [batch_size, context_length]. Sequences longer than
        `context_length` are randomly cropped instead of raising an error.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token]
                      for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                st = torch.randint(
                    len(tokens) - context_length + 1, (1,))[0].item()
                tokens = tokens[st: st + context_length]
                # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, captions):
        '''
        captions: list of strings
        '''
        text = self.tokenize(captions).to(self.device)  # B x L
        features = self.encode_text(text)  # B x D
        return features


def build_text_encoder(pretrain=True, visual_type="RN50"):
    clip_dict = {
        "visual_type": ["embed_dim", "context_length", "vocab_size",
                        "transformer_width", "transformer_heads", "transformer_layers"],
        "RN50": [1024, 77, 49408, 512, 8, 12],
        "RN50x4": [640, 77, 49408, 640, 10, 12],
        "RN50x16": [768, 77, 49408, 768, 12, 12],
        "RN50x64": [1024, 77, 49408, 1024, 16, 12],
    }
    text_encoder = CLIPTEXT(**{k: v for k, v in zip(
        clip_dict['visual_type'], clip_dict[visual_type])})
    if pretrain:
        import clip
        if visual_type in clip_dict:
            pretrained_model, _ = clip.load(visual_type, device='cpu')
        else:
            raise NotImplementedError
        state_dict = pretrained_model.state_dict()
        to_delete_keys = ["logit_scale", "input_resolution",
                          "context_length", "vocab_size"] + \
            [k for k in state_dict.keys() if k.startswith('visual.')]
        for k in to_delete_keys:
            if k in state_dict:
                del state_dict[k]
        # print('Loading pretrained CLIP')
        text_encoder.load_state_dict(state_dict)
    return text_encoder
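

# Minimal usage sketch (an illustrative assumption, not part of the original API
# surface): build the text encoder without downloading pretrained CLIP weights
# and embed two captions. Note that the `clip` package must still be installed,
# since CLIPTEXT instantiates its SimpleTokenizer in __init__.
if __name__ == "__main__":
    encoder = build_text_encoder(pretrain=False, visual_type="RN50")
    encoder.eval()
    with torch.no_grad():
        feats = encoder(["a photo of a cat", "a photo of a dog"])
    print(feats.shape)  # torch.Size([2, 1024]) for the RN50 configuration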