Spaces:
Running
Running
| # ------------------------------------------------------------------------- | |
| # MIT License | |
| # | |
| # Copyright (c) 2021 OpenAI | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| # | |
| # ------------------------------------------------------------------------- | |
| import torch | |
| import torch.utils.checkpoint as checkpoint | |
| from torch import nn | |
| from collections import OrderedDict | |
| from timm.models.layers import trunc_normal_ | |
class Attention(nn.Module):
    """Multi-head attention with separate query, key and value projections.

    Queries may come from a different sequence than the keys/values, so the
    same module is usable for self-attention and for cross-attention.

    Args:
        dim: Channel width of the inputs and of the output.
        num_heads: Number of attention heads the channels are split across.
        qkv_bias: Whether the q/k/v projection layers carry a bias term.
        qk_scale: Optional override for the logit scale. Any falsy value
            (including ``None``) selects ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape ``(B, N, C)``.
            k: Key tensor of shape ``(B, M, C)``.
            v: Value tensor; must have exactly the same shape as ``k``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = q.shape
        assert k.shape == v.shape
        B, M, C = k.shape
        head_dim = C // self.num_heads

        # Split the channel axis into (heads, head_dim).
        q = self.q_proj(q).reshape(B, N, self.num_heads, head_dim)
        k = self.k_proj(k).reshape(B, M, self.num_heads, head_dim)
        v = self.v_proj(v).reshape(B, M, self.num_heads, head_dim)

        # Scaled dot-product logits, laid out as (batch, head, query, key).
        attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        # Apply the configured attention dropout. The module used to be
        # constructed but never called, so ``attn_drop`` had no effect.
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge the heads back into C channels.
        x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class TransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, cross-attention, then an MLP.

    Each of the three sub-layers normalises its input with its own LayerNorm
    and adds its result back onto the running activation (residual).

    Args:
        d_model: Channel width of the layer.
        nhead: Number of attention heads used by both attention modules.
        dropout: Dropout probability used after the attention projections,
            inside the MLP, and on the MLP output.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dropout=0.1,
    ):
        super().__init__()
        hidden_width = d_model * 4

        self.self_attn = Attention(d_model, nhead, proj_drop=dropout)
        self.cross_attn = Attention(d_model, nhead, proj_drop=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_model),
        )

    def forward(self, x, mem):
        """Update ``x`` using itself and the memory ``mem``.

        Args:
            x: Activations being decoded.
            mem: Memory tensor used as both keys and values for cross-attention.

        Returns:
            The updated activations.
        """
        # Self-attention: the same normalised tensor is query, key and value.
        normed = self.norm1(x)
        x = x + self.self_attn(normed, normed, normed)

        # Cross-attention: normalised activations query the memory.
        query = self.norm2(x)
        x = x + self.cross_attn(query, mem, mem)

        # Position-wise MLP with dropout on its output.
        mlp_out = self.mlp(self.norm3(x))
        x = x + self.dropout(mlp_out)
        return x
class ContextDecoder(nn.Module):
    """Stack of decoder layers that refines text features with visual memory.

    Both inputs are projected from ``visual_dim`` channels down to
    ``transformer_width``, passed through the decoder layers (text attends to
    the visual tokens), and projected back up to ``visual_dim``.

    Args:
        transformer_width: Channel width used inside the decoder layers.
        transformer_heads: Number of attention heads per layer.
        transformer_layers: Number of stacked decoder layers.
        visual_dim: Channel width of both inputs and of the output.
        dropout: Dropout probability forwarded to every decoder layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 transformer_width=256,
                 transformer_heads=4,
                 transformer_layers=6,
                 visual_dim=1024,
                 dropout=0.1,
                 **kwargs):
        super().__init__()

        # Visual tokens: norm -> linear -> norm.
        self.memory_proj = nn.Sequential(
            nn.LayerNorm(visual_dim),
            nn.Linear(visual_dim, transformer_width),
            nn.LayerNorm(transformer_width),
        )

        # Text tokens: norm -> linear.
        self.text_proj = nn.Sequential(
            nn.LayerNorm(visual_dim),
            nn.Linear(visual_dim, transformer_width),
        )

        self.decoder = nn.ModuleList(
            TransformerDecoderLayer(transformer_width, transformer_heads, dropout)
            for _ in range(transformer_layers)
        )

        self.out_proj = nn.Sequential(
            nn.LayerNorm(transformer_width),
            nn.Linear(transformer_width, visual_dim),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; invoked for every module via ``apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, text, visual):
        """Decode ``text`` against ``visual``.

        Args:
            text: Text features whose last dimension is ``visual_dim``.
            visual: Visual tokens of shape ``(B, N, C)``.

        Returns:
            The refined text features, projected back to ``visual_dim`` channels.
        """
        # The values are unused, but the unpacking also rejects non-3-D input.
        _batch, _tokens, _channels = visual.shape

        memory = self.memory_proj(visual)
        hidden = self.text_proj(text)

        for decoder_layer in self.decoder:
            hidden = decoder_layer(hidden, memory)

        return self.out_proj(hidden)
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        """Return ``x`` scaled element-wise by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Transformer block: LayerNorm + self-attention, then LayerNorm + MLP.

    Both sub-layers are wrapped in residual connections.

    Args:
        d_model: Channel width of the block.
        n_head: Number of heads for ``nn.MultiheadAttention``.
        attn_mask: Optional attention mask passed to every attention call.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        hidden_width = d_model * 4

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)

        mlp_layers = OrderedDict()
        mlp_layers['c_fc'] = nn.Linear(d_model, hidden_width)
        mlp_layers['gelu'] = QuickGELU()
        mlp_layers['c_proj'] = nn.Linear(hidden_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)

        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor):
        """Self-attend over ``x`` and return only the attended output."""
        # The mask is a plain attribute, so it is moved to the dtype and
        # device of the input on every call and the converted copy is stored.
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask

        attended = self.attn(
            x, x, x,
            need_weights=False,
            attn_mask=self.attn_mask,
            key_padding_mask=key_padding_mask,
        )
        return attended[0]

    def forward(self, x: torch.Tensor, key_padding_mask=None):
        """Apply the attention and MLP sub-layers, each with a residual add."""
        attn_out = self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        x = x + mlp_out
        return x
class Transformer(nn.Module):
    """Sequential stack of ``ResidualAttentionBlock`` layers.

    Args:
        width: Channel width of every block.
        layers: Number of blocks in the stack.
        heads: Number of attention heads per block.
        attn_mask: Optional attention mask handed to every block.
        use_checkpoint: If true, run each block through
            ``torch.utils.checkpoint.checkpoint`` in ``forward``.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

        # Standard deviations for the weight re-initialisation below; the
        # projection std additionally shrinks with the depth of the stack.
        proj_std = (self.width**-0.5) * ((2 * self.layers)**-0.5)
        attn_std = self.width**-0.5
        fc_std = (2 * self.width)**-0.5

        for blk in self.resblocks:
            nn.init.normal_(blk.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(blk.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(blk.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(blk.mlp.c_proj.weight, std=proj_std)

        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor):
        """Run ``x`` through every block in order and return the result."""
        # NOTE(review): recent PyTorch releases appear to warn when
        # `use_reentrant` is not passed to checkpoint() - confirm against the
        # torch version this project targets.
        for blk in self.resblocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x
class TextTransformer(nn.Module):
    """Token-level text encoder that returns one feature vector per sequence.

    Args:
        context_length: Number of positions in the positional embedding and
            side length of the attention mask.
        width: Channel width; the number of heads is ``width // 64``.
        layers: Number of transformer blocks.
        vocab_size: Number of rows in the token embedding table.
        use_checkpoint: Forwarded to ``Transformer``.
    """

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
    ):
        super().__init__()
        heads = width // 64
        self.context_length = context_length
        self.width = width

        causal_mask = self.build_attention_mask()
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            attn_mask=causal_mask,
            use_checkpoint=use_checkpoint)

        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.ln_final = nn.LayerNorm(width)
        self.token_embedding = nn.Embedding(vocab_size, width)

        # Initialisation of the two embedding tables.
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` additive mask.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are ``0``.
        """
        side = self.context_length
        # Start from all -inf, then keep only the strict upper triangle
        # (triu_ sets the remaining entries to zero).
        return torch.full((side, side), float('-inf')).triu_(1)

    def forward(self, text):
        """Encode token ids and return the feature at each row's argmax position.

        Args:
            text: Integer token ids, indexed as ``[batch, position]``.

        Returns:
            One feature vector per sequence.
        """
        embedded = self.token_embedding(text) + self.positional_embedding

        # The transformer works on LND layout: NLD -> LND and back again.
        hidden = self.transformer(embedded.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden)

        # hidden.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        return hidden[batch_index, eot_index]