# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import math
from functools import partial
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.models.vision_transformer import PatchEmbed, Block
from torch import einsum, nn
from torch._C import Value  # NOTE(review): not referenced in this part of the file
from transformers import EncoderDecoderModel, BertTokenizer, AutoTokenizer


class FocalLoss(nn.CrossEntropyLoss):
    """Focal loss for classification tasks on imbalanced datasets.

    The per-element cross entropy is scaled by ``(1 - p_target) ** gamma``
    where ``p_target`` is the softmax probability of the target class.

    Args:
        gamma: focusing exponent applied to ``1 - p_target``.
        alpha: optional per-class weight tensor, forwarded as ``weight``.
        ignore_index: target value that contributes zero loss.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
    """

    def __init__(self, gamma=1.0, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        """Return the focal loss of logits ``input_`` against ``target``.

        ``input_`` carries the class dimension at dim 1; ``target`` holds
        class indices.  With ``'mean'`` the average is taken over every
        element, ignored positions included (they contribute zero).
        """
        # Always compute the *unreduced* cross entropy.  Delegating to
        # ``super().forward`` would honour ``self.reduction`` (overwritten in
        # ``__init__``) and collapse the loss to a scalar before the focal
        # weights are applied.
        cross_entropy = F.cross_entropy(
            input_,
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction='none',
        )
        # Temporarily map the ignore index to class 0 so that it is a valid
        # gather index.  These positions do not contribute to the final loss
        # because their cross entropy is already zero.
        target = target * (target != self.ignore_index).long()
        input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss


# helper functions

def prob_mask_like(t, prob):
    """Return a boolean tensor shaped like ``t``, True with probability ``prob``."""
    return torch.zeros_like(t).float().uniform_(0, 1) < prob


def mask_with_tokens(t, token_ids):
    """Return a boolean mask that is True wherever ``t`` equals any of ``token_ids``."""
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask


def get_mask_subset_with_prob(mask, prob):
    """Randomly pick a subset of the True positions of a 2-D boolean ``mask``.

    At most ``ceil(prob * seq_len)`` positions are chosen per row, and per row
    no more than ``ceil(prob * num_true)`` of them are kept.  Returns a boolean
    tensor with the same shape as ``mask``.
    """
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    # positions outside ``mask`` get a very low score so topk never prefers them
    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    # shift indices by one so that index 0 can act as a discard slot
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()


def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    return val if exists(val) else d


# normalization
# they use layernorm without bias, something that pytorch does not offer

class LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learned scale and a fixed zero bias."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        # registered as a buffer (not a parameter) so it is never trained
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


# residual

class Residual(nn.Module):
    """Wrap ``fn`` so that its output is added to its first input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


# rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    """Produce rotary position angles for a head dimension of ``dim``."""

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        """Return the angles for positions ``0 .. max_seq_len - 1``.

        The frequency table is concatenated with itself along the last
        dimension so it lines up with ``rotate_half``.
        """
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    """Split the last dim of ``x`` into halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``."""
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GELU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    """Split the last dim in two and gate the first half with ``silu`` of the second."""

    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    """Causal multi-query self attention and a SwiGLU feedforward computed in parallel.

    One fused projection yields the queries, the single shared key / value
    head, and the feedforward hidden activations.  The block returns
    ``attention_out + feedforward_out``; wrap it in ``Residual`` to add the
    skip connection.

    Args:
        dim: model width.
        dim_head: width of each attention head.
        heads: number of query heads.
        ff_mult: feedforward hidden width as a multiple of ``dim``.
        attn_drop_rate: probability of masking an attention logit while
            the module is in training mode.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, attn_drop_rate=0.0):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # split sizes of the fused projection: q, k, v, feedforward (x2 for SwiGLU)
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )
        self.attn_drop_rate = attn_drop_rate

        # for caching causal mask and rotary embeddings
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Return an ``(n, n)`` boolean causal mask (True above the diagonal), cached."""
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Return rotary angles for ``n`` positions, cached."""
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        Performs self attention and feedforward

        ``attn_mask``, when given, is a ``(b, i, j)`` boolean tensor; False
        entries are masked out in addition to the causal mask.

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)
        neg_inf = -torch.finfo(sim.dtype).max

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, neg_inf)

        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, neg_inf)

        # attention dropout on the logits: training only (consistent with
        # CrossAttention) and sampled on the logits' own device, so it also
        # works on CPU or a non-default GPU.  Already-masked entries stay
        # masked, so drawing noise for every entry is harmless.
        if self.training and self.attn_drop_rate != 0.:
            keep = torch.rand(sim.shape, device=sim.device) > self.attn_drop_rate
            sim = sim.masked_fill(~keep, neg_inf)

        # attention (subtract the row max for numerical stability)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward

class CrossAttention(nn.Module):
    """Multi-query cross attention with an optional parallel SwiGLU feedforward.

    Queries come from ``x``; a single shared key / value head comes from
    ``context``.  The skip connection is not included (wrap in ``Residual``).

    Args:
        dim: width of ``x``.
        context_dim: width of ``context`` (defaults to ``dim``).
        dim_head: width of each attention head.
        heads: number of query heads.
        parallel_ff: add a feedforward branch computed from the normed ``x``.
        ff_mult: feedforward hidden width as a multiple of ``dim``.
        norm_context: apply a LayerNorm to ``context`` first.
        dropout: probability of masking an attention logit during training.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
        dropout=0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.dropout = dropout

        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Use text and query, and image as kv

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)

        # scale
        q = q * self.scale

        # get key / values
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # dropout on the logits; noise is drawn on the logits' own device so
        # this also runs on CPU, and is skipped entirely when the rate is 0
        if self.training and self.dropout > 0.:
            keep = torch.rand(sim.shape, device=sim.device) > self.dropout
            sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)

        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)

        return out


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # the class token gets an all-zero position embedding in row 0
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode each of the two grid axes with half of ``embed_dim`` and concatenate."""
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone

    On top of the MAE encoder / decoder the model holds a text captioner:
    unimodal text layers followed by multimodal layers that cross-attend to
    the encoder output.  ``forward`` returns the MAE reconstruction loss and
    the captioning cross-entropy loss.

    Many constructor flags are only stored on the instance; nothing in this
    class reads them back.
    """

    # encoder block indices whose outputs are tapped when ``extract_multi_level`` is set
    _MULTI_LEVEL_BLOCKS = (2, 5, 8)

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True,
                 unimodal_depth=2, multimodal_depth=8, dim_head=64, heads=8,
                 ff_mult=4, extract_multi_level=False, use_focal_loss=False, focal_gamma=1.0,
                 less_u=False, use_weak_negative=False, use_label_smooth=False, ls_coef=0.1,
                 use_maximum_entropy=False, ce_additional=False, use_word_weights=False, use_token_pos=False,
                 use_expect_k=False, use_top_k=False, mae_decoder_caption=False, decoder_slot_depth=2,
                 disable_decoder_vis_token_grad=False,
                 cross_attn_dropout=0.0, predict_next_k_words=False, next_k=3, masked_text=False,
                 masked_text_ratio=0.25, text_length=70,
                 projector_layer=0, uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4, text_drop_attn=0.):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.mae_decoder_depth = decoder_depth
        self.mae_decoder_caption = mae_decoder_caption

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        if self.mae_decoder_caption:
            self.decoder_slot_layers = nn.ModuleList([])
            for _ in range(decoder_slot_depth):
                self.decoder_slot_layers.append(
                    Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult,)),
                )
            # NOTE(review): assumed to belong to the caption-decoder branch;
            # confirm against the original layout of this constructor.
            self.decoder_caption_proj = nn.Linear(decoder_embed_dim, embed_dim)

        self.disable_decoder_vis_token_grad = disable_decoder_vis_token_grad

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # encoder to decoder
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        # --------------------------------------------------------------------------
        # captioner specifics
        # unimodal layer is for text tokens.
        # multimodal layer is for text to query from image.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", )

        # token embeddings
        # NOTE: +1 for mask token used by MLM objective
        # self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)
        self.token_emb = nn.Embedding(len(self.tokenizer.vocab), uni_dim)
        self.text_cls_token = nn.Parameter(torch.randn(uni_dim))

        self.embed_dim = embed_dim
        self.uni_dim = uni_dim

        # unimodal layers
        # TODO: search on the four parameters
        # uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4
        self.text_drop_attn = text_drop_attn
        self.unimodal_layers = nn.ModuleList([])
        for _ in range(unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=uni_dim, dim_head=uni_dim_head,
                                                  heads=uni_heads, ff_mult=uni_ff_mult,
                                                  attn_drop_rate=self.text_drop_attn)),
            )

        # project text tokens to the multimodal width only when the widths differ
        self.need_uni_2_mul_proj = False
        if uni_dim != embed_dim:
            self.need_uni_2_mul_proj = True
            self.uni_2_mul_proj = nn.Linear(uni_dim, embed_dim)

        # multimodal layers
        self.multimodal_layers = nn.ModuleList([])
        self.less_u = less_u
        if less_u:
            # two cross-attention blocks per layer, no self attention
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))
        else:
            # causal self attention followed by cross attention per layer
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(ParallelTransformerBlock(dim=embed_dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))

        # to logits: for softmax caption loss
        self.to_logits = nn.Sequential(
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
        )

        self.ce_additional = ce_additional
        if ce_additional:
            # to logits: for other losses
            self.to_logits_1 = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
            )

        nn.init.normal_(self.token_emb.weight, std=0.02)

        # hard-coded special-token ids for the tokenizer loaded above
        self.pad_id = 0
        self.cls_id = 101
        self.sep_id = 102

        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.extract_multi_level = extract_multi_level
        if self.extract_multi_level:
            # one projector per tapped encoder block
            self.projectors = nn.ModuleList([nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.GELU(),
                nn.Linear(embed_dim // 2, embed_dim),
                norm_layer(embed_dim)
            ) for _ in self._MULTI_LEVEL_BLOCKS])
        # --------------------------------------------------------------------------

        self.use_focal_loss = use_focal_loss

        self.use_weak_negative = use_weak_negative

        self.use_label_smooth = use_label_smooth
        self.ls_coef = ls_coef

        self.use_entropy = use_maximum_entropy

        self.use_word_weights = use_word_weights
        self.use_token_pos = use_token_pos

        self.predict_next_k_words = predict_next_k_words
        self.next_k = next_k
        self.pad = torch.nn.ReplicationPad1d((0, self.next_k - 1))

        self.use_expect_k = use_expect_k
        self.use_top_k = use_top_k

        # keep per-token losses when they are going to be re-weighted afterwards
        if self.use_word_weights or self.use_token_pos:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='none')
        else:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='mean')

        self.masked_text = masked_text
        self.masked_text_ratio = masked_text_ratio
        # self.text_mask_token = nn.Parameter(torch.randn(embed_dim))
        self.mask_token_id = len(self.tokenizer.vocab)

        # self.text_position_embed = nn.Parameter(torch.zeros(1, text_length, embed_dim), requires_grad=False)
        self.text_length = text_length

        self.latent_projector_layer = projector_layer
        if self.latent_projector_layer != 0:
            self.latent_projector = self._build_latent_projector(embed_dim, self.latent_projector_layer)

        self.initialize_weights()

    @staticmethod
    def _build_latent_projector(embed_dim, num_layers):
        """Return an MLP of ``num_layers`` Linear layers with ReLU in between.

        Every layer is a distinct module.  (Repeating a list that holds one
        ``nn.Linear`` instance would reuse the same object and silently tie
        the weights of all hidden layers.)  The resulting state-dict keys are
        the same as for the repeated-list construction.
        """
        layers = []
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(embed_dim, embed_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(embed_dim, embed_dim))
        return nn.Sequential(*layers)

    def initialize_weights(self):
        """Fill the fixed position embeddings and initialise the learnable weights."""
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # text_pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, )
        # torch.nn.init.xavier_normal_(self.text_position_embed) # learnable text position embedding

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        # torch.nn.init.normal_(self.text_mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-init ``nn.Linear`` weights (zero bias); reset ``nn.LayerNorm`` to identity."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, C, H, W)
        x: (N, L, patch_size**2 * C)

        The channel count is taken from ``imgs`` rather than fixed to 3.
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        c = imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C)
        imgs: (N, C, H, W)

        The channel count is derived from the last dimension of ``x``.
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        c = x.shape[2] // (p * p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns ``(x_masked, mask, ids_restore, ids_keep)`` where ``mask`` is
        [N, L] with 0 for kept and 1 for removed positions.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, x, mask_ratio):
        """Embed, randomly mask and encode the image patches.

        Returns ``(x, mask, ids_restore, ids_keep)``.

        NOTE(review): with ``extract_multi_level`` set, the return value is
        the 3-tuple ``(multi_level_feats, mask, ids_restore)`` instead, which
        ``forward`` cannot unpack.  Left as is because callers outside this
        file may rely on that arity.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.extract_multi_level:
            multi_level_feats = []
            # apply Transformer blocks, tapping the selected intermediate outputs
            for blk_idx, blk in enumerate(self.blocks):
                x = blk(x)
                if blk_idx in self._MULTI_LEVEL_BLOCKS:
                    projector = self.projectors[self._MULTI_LEVEL_BLOCKS.index(blk_idx)]
                    multi_level_feats.append(projector(x))
            x = self.norm(x)
            multi_level_feats.append(x)

            return multi_level_feats, mask, ids_restore

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, ids_keep

    def forward_decoder(self, x, ids_restore):
        """Reconstruct all patches from the encoded visible tokens.

        Returns ``(pred, decoder_feat)``: the per-patch prediction without
        the cls position, and a list with the output of the middle decoder
        block.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        decoder_feat = []
        for idx, blk in enumerate(self.decoder_blocks):
            x = blk(x)
            if idx == self.mae_decoder_depth // 2:
                decoder_feat.append(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x, decoder_feat

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # normalise every target patch by its own mean / variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def embed_text(self, text):
        """Embed token ids and run them through the unimodal text layers.

        A learned text cls token is appended for the unimodal pass and
        stripped again before returning, so the result has one entry per
        input token.
        """
        batch = text.shape[0]

        seq = text.shape[1]

        text_tokens = self.token_emb(text)

        # append text cls tokens
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        # go through unimodal layers
        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        if self.need_uni_2_mul_proj:
            text_tokens = self.uni_2_mul_proj(text_tokens)

        # drop the trailing text cls token
        return text_tokens[:, :-1]

    def forward(self, imgs, caption_ids=None, attention_mask=None, mask_ratio=0.75,
                freeze_bert=False, teacher_forcing=False, caption_only=False,
                encoder_only=False, word_weights=None, syn_count=None):
        """Compute the MAE reconstruction loss and the captioning loss.

        Args:
            imgs: image batch fed to the encoder.
            caption_ids: token ids, [N, Len]; required.  Positions ``[:-1]``
                are the inputs and ``[1:]`` the next-token labels.
            mask_ratio: fraction of patches hidden from the encoder.
            caption_only: skip the MAE decoder; the MAE loss is then ``0.``.
            attention_mask, freeze_bert, teacher_forcing, encoder_only,
            word_weights, syn_count: accepted but not used here.

        Returns:
            ``(mae_loss, caption_loss, None)``.

        Raises:
            ValueError: if ``caption_ids`` is None.
        """
        if caption_ids is None:
            raise ValueError("caption_ids is required to compute the caption loss")

        latent, mask, ids_restore, _ids_keep = self.forward_encoder(imgs, mask_ratio)

        if not caption_only:
            pred, _decoder_feat = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*C]
            mae_loss = self.forward_loss(imgs, pred, mask)
        else:
            mae_loss = 0.

        if self.latent_projector_layer != 0:
            latent = self.latent_projector(latent)

        # latent: visual info: N, L, C
        # caption_ids: N, Len
        text, labels = caption_ids[:, :-1], caption_ids[:, 1:]

        seq = text.shape[1]
        text_tokens = self.embed_text(text)  # N, Len, C

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        if not self.less_u:
            for attn_ff, cross_attn in self.multimodal_layers:
                # the mask was built with a cls slot; cut it back to the text length
                text_tokens = attn_ff(text_tokens, attn_mask=attn_mask[:, :-1, :-1])
                text_tokens = cross_attn(text_tokens, latent)
        else:
            for cross_attn1, cross_attn2 in self.multimodal_layers:
                text_tokens = cross_attn1(text_tokens, latent)
                text_tokens = cross_attn2(text_tokens, latent)

        logits = self.to_logits(text_tokens)  # N, Len, NVocab
        logits = logits.reshape(-1, len(self.tokenizer.vocab))
        labels = labels.reshape(-1)

        caption_loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id,)

        return mae_loss, caption_loss, None


def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build the small variant: patch 16, width 384, depth 12, 6 heads."""
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=384, depth=12, num_heads=6,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build the base variant: patch 16, width 768, depth 12, 12 heads."""
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build the large variant: patch 16, width 1024, depth 24, 16 heads."""
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build the huge variant: patch 14, width 1280, depth 32, 16 heads."""
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks