# -------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# -------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

try:
    from timm.models.layers import trunc_normal_
except ImportError:  # timm not installed: use PyTorch's initializer (same signature)
    from torch.nn.init import trunc_normal_


class Attention(nn.Module):
    """Multi-head attention with separate query, key and value projections.

    Inputs are batch-first: ``q`` is (B, N, dim) and ``k``/``v`` are (B, M, dim),
    so ``N`` and ``M`` may differ (as in cross-attention).

    Args:
        dim: Channel width of queries, keys, values and the output. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v projections have a bias term.
        qk_scale: Optional override of the logit scale. A falsy value (``None``
            or 0) falls back to ``head_dim ** -0.5``; setting it explicitly lets
            weights trained with a different scale be reused.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v):
        """Attend from ``q`` (B, N, dim) to ``k``/``v`` (B, M, dim); return (B, N, dim)."""
        B, N, C = q.shape
        assert k.shape == v.shape, "k and v must have the same shape"
        _, M, _ = k.shape
        heads = self.num_heads

        q = self.q_proj(q).reshape(B, N, heads, C // heads)
        k = self.k_proj(k).reshape(B, M, heads, C // heads)
        v = self.v_proj(v).reshape(B, M, heads, C // heads)

        # Per-head logits of shape (B, heads, N, M), normalised over the key axis.
        attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        # Dropout on the attention weights, as configured by ``attn_drop``.
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge the heads back into C channels.
        x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, cross-attention to ``mem``, then an MLP.

    Each sub-layer reads a LayerNorm-ed copy of ``x`` and its output is added
    back onto ``x`` (residual connection).
    """

    def __init__(
        self,
        d_model,
        nhead,
        dropout=0.1,
    ):
        super().__init__()
        self.self_attn = Attention(d_model, nhead, proj_drop=dropout)
        self.cross_attn = Attention(d_model, nhead, proj_drop=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x, mem):
        """Update ``x`` (B, N, d_model) using itself and memory ``mem`` (B, M, d_model)."""
        q = k = v = self.norm1(x)
        x = x + self.self_attn(q, k, v)
        q = self.norm2(x)
        # ``mem`` is used as keys/values as-is; it is not normalised in this layer.
        x = x + self.cross_attn(q, mem, mem)
        x = x + self.dropout(self.mlp(self.norm3(x)))
        return x


class ContextDecoder(nn.Module):
    """Transformer decoder that refines ``text`` features by attending to ``visual`` features.

    Both inputs are projected from ``visual_dim`` to ``transformer_width``. The
    projected text is used as queries and the projected visual features as
    cross-attention memory in ``transformer_layers`` decoder layers. The result
    is projected back to ``visual_dim``.

    Extra keyword arguments are accepted and ignored (presumably so the module
    can be built from a config carrying additional keys — TODO confirm).
    """

    def __init__(self,
                 transformer_width=256,
                 transformer_heads=4,
                 transformer_layers=6,
                 visual_dim=1024,
                 dropout=0.1,
                 **kwargs):
        super().__init__()

        self.memory_proj = nn.Sequential(
            nn.LayerNorm(visual_dim),
            nn.Linear(visual_dim, transformer_width),
            nn.LayerNorm(transformer_width),
        )

        self.text_proj = nn.Sequential(
            nn.LayerNorm(visual_dim),
            nn.Linear(visual_dim, transformer_width),
        )

        self.decoder = nn.ModuleList([
            TransformerDecoderLayer(transformer_width, transformer_heads, dropout)
            for _ in range(transformer_layers)
        ])

        self.out_proj = nn.Sequential(
            nn.LayerNorm(transformer_width),
            nn.Linear(transformer_width, visual_dim)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) Linear weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, text, visual):
        """Decode ``text`` (B, N, visual_dim) against ``visual`` (B, M, visual_dim).

        Returns a tensor of shape (B, N, visual_dim).
        """
        visual = self.memory_proj(visual)
        x = self.text_proj(text)

        for layer in self.decoder:
            x = layer(x, visual)

        return self.out_proj(x)


class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a QuickGELU MLP, each with a residual.

    Uses ``nn.MultiheadAttention`` with its default (sequence-first) layout, so
    ``x`` is expected as (L, N, d_model).

    Args:
        d_model: Channel width.
        n_head: Number of attention heads.
        attn_mask: Optional additive attention mask passed to every forward call.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = nn.LayerNorm(d_model)
        # Stored as a plain attribute rather than a registered buffer, which is
        # presumably why ``attention`` re-casts it to the input's dtype/device.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        """Self-attention over ``x`` with the stored mask and optional key padding mask."""
        # Cast/move the mask to match ``x`` and keep the converted copy on the
        # module, so later calls with the same dtype/device reuse it.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask,
                         key_padding_mask=key_padding_mask)[0]

    def forward(self, x: torch.Tensor, key_padding_mask=None):
        x = x + self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """Stack of ``layers`` ResidualAttentionBlocks sharing one ``attn_mask``.

    Args:
        width: Channel width.
        layers: Number of blocks.
        heads: Attention heads per block.
        attn_mask: Optional additive attention mask used by every block.
        use_checkpoint: Recompute each block in backward via
            ``torch.utils.checkpoint`` instead of storing activations.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
                 use_checkpoint=False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

        # Attention/MLP weights start at width**-0.5-based stds; the residual
        # output projections are further scaled down by (2 * layers)**-0.5.
        proj_std = (self.width**-0.5) * ((2 * self.layers)**-0.5)
        attn_std = self.width**-0.5
        fc_std = (2 * self.width)**-0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        """Run ``x`` (L, N, width) through all blocks.

        ``key_padding_mask`` is optional and is forwarded unchanged to every
        block's ``nn.MultiheadAttention`` call.
        """
        for resblock in self.resblocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(resblock, x, key_padding_mask)
            else:
                x = resblock(x, key_padding_mask=key_padding_mask)
        return x


class TextTransformer(nn.Module):
    """Causal text encoder that returns one feature vector per sequence.

    The model adds token and learned positional embeddings and runs them
    through a causally masked ``Transformer``. It returns the final-LayerNorm
    feature at the position of the largest token id in each sequence, which is
    the end-of-text token under this model's tokenization.

    Args:
        context_length: Sequence length; sizes the positional embedding and mask.
        width: Channel width; the head count is ``width // 64``, so widths
            below 64 yield zero heads.
        layers: Number of transformer blocks.
        vocab_size: Size of the token embedding table.
        use_checkpoint: Forwarded to ``Transformer``.
    """

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
    ):
        super().__init__()
        heads = width // 64
        self.context_length = context_length
        self.width = width
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            attn_mask=self.build_attention_mask(),
            use_checkpoint=use_checkpoint)

        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.ln_final = nn.LayerNorm(width)
        self.token_embedding = nn.Embedding(vocab_size, width)
        nn.init.normal_(self.token_embedding.weight, std=0.02)

        # initialization
        nn.init.normal_(self.positional_embedding, std=0.01)

    def build_attention_mask(self):
        """Return an additive causal mask of shape (context_length, context_length).

        Entries strictly above the diagonal are ``-inf`` and all others are 0,
        so position i can only attend to positions <= i.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # keep -inf above the diagonal; zero the diagonal and below
        return mask

    def forward(self, text):
        """Encode integer token ids ``text`` and return features of shape (batch, width).

        ``text`` is expected as (batch, context_length) so the positional
        embedding and causal mask line up.
        """
        x = self.token_embedding(text)
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]

        return x