# MUG_caption / model.py
# Last commit 71e1a05 by tennant: "add location" (32.4 kB).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import torch
from torch._C import Value
import torch.nn as nn
import numpy as np
from timm.models.vision_transformer import PatchEmbed, Block
from transformers import EncoderDecoderModel, BertTokenizer, AutoTokenizer
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets.

    Per element: loss_i = (1 - p_i) ** gamma * CE_i, where p_i is the softmax
    probability of the target class and CE_i is the (optionally class-weighted)
    cross entropy. With gamma == 0 this reduces to ordinary cross entropy.

    Args:
        gamma: focusing parameter.
        alpha: optional per-class weight tensor, forwarded as CrossEntropyLoss ``weight``.
        ignore_index: target value that contributes no loss.
        reduction: 'mean' (normalized like nn.CrossEntropyLoss, i.e. by the
            weighted count of non-ignored targets), 'sum', or anything else for
            the unreduced per-element loss.
    '''
    def __init__(self, gamma=1.0, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        # Per-element cross entropy. F.cross_entropy is called directly with
        # reduction='none': super().forward() would use self.reduction (the user's
        # 'mean'/'sum'), returning an already-reduced scalar instead.
        cross_entropy = F.cross_entropy(input_, target, weight=self.weight,
                                        ignore_index=self.ignore_index, reduction='none')
        valid = target != self.ignore_index
        # Temporarily map ignored targets to class 0 so gather() gets valid indices.
        # Their cross entropy is 0, so they add nothing to the loss.
        safe_target = target.masked_fill(~valid, 0)
        input_prob = torch.gather(F.softmax(input_, 1), 1, safe_target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        if self.reduction == 'mean':
            # Same normalization as nn.CrossEntropyLoss: the (weighted) number of
            # non-ignored targets, not the total element count.
            if self.weight is not None:
                denom = (self.weight[safe_target] * valid).sum()
            else:
                denom = valid.sum()
            return loss.sum() / denom
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss
# helper functions
import math
from functools import reduce
def prob_mask_like(t, prob):
    """Return a bool tensor shaped like ``t``; each entry is True with probability ``prob``.

    Uniform noise in [0, 1) is drawn in float32 and compared against ``prob``.
    """
    noise = torch.empty_like(t, dtype=torch.float).uniform_(0, 1)
    return noise < prob
def mask_with_tokens(t, token_ids):
    """Return a bool tensor shaped like ``t`` that is True wherever ``t`` equals any id in ``token_ids``.

    An empty ``token_ids`` yields an all-False mask.
    """
    hits = torch.zeros_like(t, dtype=torch.bool)
    for token_id in token_ids:
        hits = hits | (t == token_id)
    return hits
def get_mask_subset_with_prob(mask, prob):
    """Randomly select a ``prob`` fraction of the True positions in each row of ``mask``.

    For every row, ``ceil(num_true * prob)`` positions are chosen uniformly at
    random among that row's True positions. Positions where ``mask`` is False are
    never selected.

    Args:
        mask: bool tensor of shape (batch, seq_len) marking candidate positions.
        prob: fraction of candidates to select (expected in [0, 1]).

    Returns:
        bool tensor of shape (batch, seq_len) with the selected positions set.
    """
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    budget = (num_tokens * prob).ceil()  # (batch, 1): how many to select per row
    # Top-k rank r is surplus once r reaches that row's budget. (The previous
    # version compared a position-wise cumsum against the budget, which for rows
    # with leading False entries kept too many ranks, including non-candidates.)
    ranks = torch.arange(max_masked, device=device).unsqueeze(0)
    mask_excess = ranks >= budget

    # Non-candidates get a score far below any uniform sample, so top ranks come
    # from candidate positions first.
    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    # Shift by one so column 0 acts as a dummy slot that absorbs surplus ranks.
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)
def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is ``None``."""
    if val is None:
        return d
    return val
# normalization
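# they use layernorm without a bias term; nn.LayerNorm only accepts `bias=False` from
# PyTorch 2.1 on, so here the bias is kept as a fixed, non-trainable zero buffer instead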
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and a fixed zero bias.

    The bias is a registered buffer: it is saved in the state dict and moves with
    the module across devices, but it is not a parameter and is never trained.
    """
    def __init__(self, dim):
        super().__init__()
        scale = torch.ones(dim)
        self.gamma = nn.Parameter(scale)
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        feature_shape = x.shape[-1:]
        return F.layer_norm(x, feature_shape, weight=self.gamma, bias=self.beta)
# residual
class Residual(nn.Module):
    """Skip connection: returns ``fn(x, *args, **kwargs) + x``.

    Extra positional and keyword arguments are forwarded to ``fn`` only; the
    residual added back is the first input ``x``.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
    """Rotary position angles: for each position, the products position * inv_freq, duplicated.

    ``forward(max_seq_len, device=...)`` returns a tensor of shape
    (max_seq_len, 2 * ceil(dim / 2)) whose two halves are identical.
    """
    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        return angles.repeat(1, 2)
def rotate_half(x):
    """Split the last dimension into two equal halves (a, b) and return cat(-b, a)."""
    halves = x.reshape(*x.shape[:-1], 2, -1)
    first, second = halves.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``: t * cos(pos) + rotate_half(t) * sin(pos)."""
    cos_part = t * pos.cos()
    sin_part = rotate_half(t) * pos.sin()
    return cos_part + sin_part
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GELU for gating the feedforward
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
    """Gated linear unit with a SiLU gate.

    The last dimension is split into (value, gate); the output is value * silu(gate),
    so the output's last dimension is half the input's.
    """
    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelTransformerBlock(nn.Module):
    """Pre-norm causal self-attention and SwiGLU feedforward computed in parallel.

    A single fused projection yields the queries (``heads`` heads of width
    ``dim_head``), one shared key and one shared value (multi-query attention),
    and the feedforward input. The block returns ``attn_out + ff_out``; wrap it
    in ``Residual`` to add the skip connection.

    Args:
        dim: input/output width.
        dim_head: width of each query head and of the shared key/value.
        heads: number of query heads.
        ff_mult: feedforward expansion factor.
        attn_drop_rate: probability of masking each attention logit, applied
            only in training mode.
    """
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, attn_drop_rate=0.0):
        super().__init__()
        self.norm = LayerNorm(dim)
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # widths of the q / k / v / feedforward slices of the fused projection;
        # the feedforward slice is doubled because SwiGLU halves it again
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )
        self.attn_drop_rate = attn_drop_rate
        # for caching causal mask and rotary embeddings (non-persistent buffers)
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Return an (n, n) bool mask, True strictly above the diagonal; cached and sliced when large enough."""
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Return rotary angles for the first ``n`` positions; cached and sliced when large enough."""
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]
        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        Performs self attention and feedforward
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        attn_mask: optional bool tensor (b, i, j); positions where it is False are
        masked out in addition to the causal mask.
        """
        n, device, h = x.shape[1], x.device, self.heads
        # pre layernorm
        x = self.norm(x)
        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
        # split heads: queries are multi-head, the key/value are a single shared head
        # (multi-query attention, https://arxiv.org/abs/1911.02150)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
        # scale
        q = q * self.scale
        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)
        mask_value = -torch.finfo(sim.dtype).max
        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, mask_value)
        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, mask_value)
        # Attention-logit dropout: only while training, so eval/inference stays
        # deterministic. torch.rand on sim's device works on CPU as well as CUDA.
        if self.training and self.attn_drop_rate > 0.:
            drop = torch.rand(sim.shape, device=sim.device) < self.attn_drop_rate
            sim = sim.masked_fill(drop, mask_value)
        # attention (subtract the detached row max for numerical stability)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)
        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
class CrossAttention(nn.Module):
    """Multi-query cross attention with an optional parallel SwiGLU feedforward.

    ``x`` supplies the queries (``heads`` heads); ``context`` is projected to a
    single shared key/value head. The output has the width of ``x``; wrap in
    ``Residual`` for a skip connection.

    Args:
        dim: width of the query stream ``x``.
        context_dim: width of ``context``; defaults to ``dim``.
        dim_head: per-head width of the queries and of the shared key/value.
        heads: number of query heads.
        parallel_ff: if True, add a SwiGLU feedforward of the normalized ``x``
            to the attention output.
        ff_mult: feedforward expansion factor.
        norm_context: if True, layer-normalize ``context`` before projecting it.
        dropout: probability of masking each attention logit, applied only in
            training mode.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
        dropout=0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)
        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = dropout
        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Use text and query, and image as kv
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)
        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        # scale
        q = q * self.scale
        # get key / values (one shared head)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)
        # Attention-logit dropout: only in training and only when a rate is set.
        # torch.rand on sim's device keeps this working on CPU as well as CUDA.
        if self.training and self.dropout > 0.:
            drop = torch.rand(sim.shape, device=sim.device) < self.dropout
            sim = sim.masked_fill(drop, -torch.finfo(sim.dtype).max)
        # attention (row-max subtraction for numerical stability)
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)
        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)
        return out
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build fixed 2-D sine-cosine position embeddings for a square patch grid.

    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    With ``cls_token`` True, an all-zero row is prepended.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid with the width coordinates passed first; both axes use the same range
    grid = np.stack(np.meshgrid(coords, coords), axis=0).reshape(2, 1, grid_size, grid_size)
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    cls_row = np.zeros((1, embed_dim))
    return np.concatenate([cls_row, pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-axis coordinate grid into (H*W, embed_dim) sine-cosine embeddings.

    The first embed_dim/2 channels encode ``grid[0]`` and the second half
    encode ``grid[1]``; ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list of positions to be encoded: size (M,) (any shape; it is flattened)
    out: (M, D) -- first D/2 columns are sines, last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    # np.float was a deprecated alias of the builtin float (float64) and was removed
    # in NumPy 1.24; use the explicit dtype so this runs on current NumPy.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone, extended with a captioner.

    Image branch: an MAE encoder runs on a random subset of patches, and a
    lightweight decoder reconstructs the pixels of the masked patches.
    Text branch: BERT-tokenized captions are embedded by causal unimodal
    ``ParallelTransformerBlock`` layers. Multimodal layers then attend from the
    text to the encoder latents, producing next-token logits over the tokenizer
    vocabulary.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True,
                 unimodal_depth=2, multimodal_depth=8, dim_head=64, heads=8,
                 ff_mult=4, extract_multi_level=False, use_focal_loss=False, focal_gamma=1.0,
                 less_u=False, use_weak_negative=False, use_label_smooth=False, ls_coef=0.1,
                 use_maximum_entropy=False, ce_additional=False, use_word_weights=False, use_token_pos=False,
                 use_expect_k=False, use_top_k=False, mae_decoder_caption=False, decoder_slot_depth=2, disable_decoder_vis_token_grad=False,
                 cross_attn_dropout=0.0, predict_next_k_words=False, next_k=3, masked_text=False, masked_text_ratio=0.25, text_length=70,
                 projector_layer=0, uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4, text_drop_attn=0.):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.mae_decoder_depth = decoder_depth
        self.mae_decoder_caption = mae_decoder_caption
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])
        if self.mae_decoder_caption:
            self.decoder_slot_layers = nn.ModuleList([])
            for _ in range(decoder_slot_depth):
                self.decoder_slot_layers.append(
                    Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                )
            self.decoder_caption_proj = nn.Linear(decoder_embed_dim, embed_dim)
            self.disable_decoder_vis_token_grad = disable_decoder_vis_token_grad
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch pixels
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        # --------------------------------------------------------------------------
        # captioner specifics
        # unimodal layer is for text tokens.
        # multimodal layer is for text to query from image.
        # NOTE(review): from_pretrained presumably downloads or reads the HF cache at construction time.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", )
        # token embeddings: one row per tokenizer vocabulary entry (no extra MLM mask row)
        self.token_emb = nn.Embedding(len(self.tokenizer.vocab), uni_dim)
        # learnable cls token, appended after the text by embed_text()
        self.text_cls_token = nn.Parameter(torch.randn(uni_dim))
        self.embed_dim = embed_dim
        self.uni_dim = uni_dim

        # unimodal layers
        # TODO: search on the four parameters
        # uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4
        self.text_drop_attn = text_drop_attn
        self.unimodal_layers = nn.ModuleList([])
        for _ in range(unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=uni_dim, dim_head=uni_dim_head,
                                                  heads=uni_heads, ff_mult=uni_ff_mult, attn_drop_rate=self.text_drop_attn)),
            )
        # project unimodal text features to the image width when the widths differ
        self.need_uni_2_mul_proj = False
        if uni_dim != embed_dim:
            self.need_uni_2_mul_proj = True
            self.uni_2_mul_proj = nn.Linear(uni_dim, embed_dim)

        # multimodal layers
        self.multimodal_layers = nn.ModuleList([])
        self.less_u = less_u
        if less_u:
            # two cross-attention layers per step (no text self-attention)
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))
        else:
            # causal text self-attention followed by cross-attention to the image latents
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(ParallelTransformerBlock(dim=embed_dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))

        # to logits: for softmax caption loss
        self.to_logits = nn.Sequential(
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
        )
        self.ce_additional = ce_additional
        if ce_additional:
            # to logits: for other losses
            self.to_logits_1 = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
            )
        nn.init.normal_(self.token_emb.weight, std=0.02)

        # special token ids (hard-coded; presumably bert-base-uncased [PAD]/[CLS]/[SEP] -- confirm)
        self.pad_id = 0
        self.cls_id = 101
        self.sep_id = 102
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.extract_multi_level = extract_multi_level
        if self.extract_multi_level:
            # one projector per tapped encoder block (indices 2, 5, 8 in forward_encoder)
            self.projectors = nn.ModuleList([nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.GELU(),
                nn.Linear(embed_dim // 2, embed_dim),
                norm_layer(embed_dim)
            ) for _ in [2, 5, 8, ]])
        # --------------------------------------------------------------------------

        # loss / objective configuration. Several of these flags are stored but are
        # not read by the forward() defined in this class.
        self.use_focal_loss = use_focal_loss
        self.use_weak_negative = use_weak_negative
        self.use_label_smooth = use_label_smooth
        self.ls_coef = ls_coef
        self.use_entropy = use_maximum_entropy
        self.use_word_weights = use_word_weights
        self.use_token_pos = use_token_pos
        self.predict_next_k_words = predict_next_k_words
        self.next_k = next_k
        self.pad = torch.nn.ReplicationPad1d((0, self.next_k - 1))
        self.use_expect_k = use_expect_k
        self.use_top_k = use_top_k
        if self.use_word_weights or self.use_token_pos:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='none')
        else:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='mean')
        self.masked_text = masked_text
        self.masked_text_ratio = masked_text_ratio
        # NOTE(review): this id equals the vocabulary size, one past the last row of
        # token_emb (built with len(vocab) rows), so it cannot be embedded as-is.
        self.mask_token_id = len(self.tokenizer.vocab)
        self.text_length = text_length

        self.latent_projector_layer = projector_layer
        if self.latent_projector_layer != 0:
            # Build fresh Linear/ReLU modules for every hidden layer. Multiplying a list
            # of modules (``[Linear, ReLU] * k``) repeats the *same* instances and so
            # silently ties all hidden layers' weights. Module indices, and therefore
            # state_dict keys, are the same as before.
            projector = []
            for _ in range(self.latent_projector_layer - 1):
                projector.extend([nn.Linear(embed_dim, embed_dim), nn.ReLU()])
            projector.append(nn.Linear(embed_dim, embed_dim))
            self.latent_projector = nn.Sequential(*projector)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the fixed sin-cos position embeddings, the patch projection,
        the cls/mask tokens, and (via ``_init_weights``) every nn.Linear / nn.LayerNorm."""
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights / zero bias for nn.Linear; unit weight / zero bias for nn.LayerNorm.

        The bias-free ``LayerNorm`` defined in this file is not an nn.LayerNorm and is left untouched.
        """
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)  -- square images, H divisible by the patch size
        x: (N, L, patch_size**2 *3), pixels ordered (row, col, channel) within each patch
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        Inverse of ``patchify``.
        x: (N, L, patch_size**2 *3), L must be a perfect square
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        Returns (x_masked, mask, ids_restore, ids_keep); mask is [N, L] with 0 = keep, 1 = removed.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, x, mask_ratio):
        """Embed patches, keep a random (1 - mask_ratio) subset, prepend the cls token, run the encoder.

        Returns (latent, mask, ids_restore, ids_keep), with latent normed and the cls token first.
        When ``extract_multi_level`` is set it instead returns (multi_level_feats, mask, ids_restore):
        projected features after blocks 2, 5 and 8, followed by the normed final output.
        """
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]
        # masking: length -> length * mask_ratio
        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)
        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        if self.extract_multi_level:
            # NOTE(review): this branch returns 3 values, while forward() unpacks 4 and
            # treats the first as a tensor; only callers expecting this form can use it.
            multi_level_feats = []
            # apply Transformer blocks
            for blk_idx, blk in enumerate(self.blocks):
                x = blk(x)
                if blk_idx in [2, 5, 8]:
                    multi_level_feats.append(self.projectors[[2, 5, 8].index(blk_idx)](x))
            x = self.norm(x)
            multi_level_feats.append(x)
            return multi_level_feats, mask, ids_restore
        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, mask, ids_restore, ids_keep

    def forward_decoder(self, x, ids_restore):
        """Restore the full patch sequence with mask tokens and predict per-patch pixels.

        Returns (pred, decoder_feat): pred is [N, L, p*p*3] without the cls token;
        decoder_feat is a list holding the hidden state after decoder block
        ``mae_decoder_depth // 2``.
        """
        # embed tokens
        x = self.decoder_embed(x)
        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        # add pos embed
        x = x + self.decoder_pos_embed
        # apply Transformer blocks
        decoder_feat = []
        for idx, blk in enumerate(self.decoder_blocks):
            x = blk(x)
            if idx == self.mae_decoder_depth // 2:
                decoder_feat.append(x)
        x = self.decoder_norm(x)
        # predictor projection
        x = self.decoder_pred(x)
        # remove cls token
        x = x[:, 1:, :]
        return x, decoder_feat

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        Returns the mean squared error over removed patches only; with norm_pix_loss
        the target pixels are normalized per patch first.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def embed_text(self, text):
        """Embed token ids [N, Len] and run the causal unimodal layers; returns [N, Len, embed_dim].

        A cls token is appended for the unimodal pass and dropped from the output.
        """
        batch = text.shape[0]
        seq = text.shape[1]
        text_tokens = self.token_emb(text)
        # append text cls tokens
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)
        # Attention mask of shape (b, seq + 1, seq + 1): the first seq rows are all
        # True (only the causal mask applies); the last (cls) row blocks padding.
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
        # go through unimodal layers
        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)
        if self.need_uni_2_mul_proj:
            text_tokens = self.uni_2_mul_proj(text_tokens)
        # drop the cls token
        text_tokens = text_tokens[:, :-1]
        return text_tokens

    def forward(self, imgs, caption_ids=None, attention_mask=None, mask_ratio=0.75,
                freeze_bert=False, teacher_forcing=False, caption_only=False,
                encoder_only=False, word_weights=None, syn_count=None):
        """Compute the MAE reconstruction loss and the next-token caption loss.

        ``caption_ids`` ([N, Len] token ids) is required despite its default; the
        model predicts caption_ids[:, 1:] from caption_ids[:, :-1]. When
        ``caption_only`` is True the MAE decoder is skipped and mae_loss is 0.
        attention_mask, freeze_bert, teacher_forcing, encoder_only, word_weights
        and syn_count are accepted but not used here.

        Returns (mae_loss, caption_loss, None).
        """
        latent, mask, ids_restore, ids_keep = self.forward_encoder(imgs, mask_ratio)
        if not caption_only:
            pred, decoder_feat = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
            mae_loss = self.forward_loss(imgs, pred, mask)
        else:
            mae_loss = 0.
        if self.latent_projector_layer != 0:
            latent = self.latent_projector(latent)
        # latent: visual info: N, L, C
        # caption_ids: N, Len
        text, labels = caption_ids[:, :-1], caption_ids[:, 1:]
        seq = text.shape[1]
        text_tokens = self.embed_text(text)  # N, Len, C
        # Same mask construction as in embed_text. After slicing off the cls row and
        # column below it is all True, so only the causal mask restricts attention.
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
        if not self.less_u:
            for attn_ff, cross_attn in self.multimodal_layers:
                text_tokens = attn_ff(text_tokens, attn_mask=attn_mask[:, :-1, :-1])
                text_tokens = cross_attn(text_tokens, latent)
        else:
            for cross_attn1, cross_attn2 in self.multimodal_layers:
                text_tokens = cross_attn1(text_tokens, latent)
                text_tokens = cross_attn2(text_tokens, latent)
        logits = self.to_logits(text_tokens)  # N, Len, NVocab
        logits = logits.reshape(-1, len(self.tokenizer.vocab))
        labels = labels.reshape(-1)
        caption_loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id, )
        return mae_loss, caption_loss, None
def mae_vit_small_patch16_dec512d8b(**kwargs):
    """ViT-Small/16 encoder (384-d, 12 blocks, 6 heads) with a 512-d, 8-block, 16-head MAE decoder.

    Extra keyword arguments are forwarded to MaskedAutoencoderViT.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """ViT-Base/16 encoder (768-d, 12 blocks, 12 heads) with a 512-d, 8-block, 16-head MAE decoder.

    Extra keyword arguments are forwarded to MaskedAutoencoderViT.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """ViT-Large/16 encoder (1024-d, 24 blocks, 16 heads) with a 512-d, 8-block, 16-head MAE decoder.

    Extra keyword arguments are forwarded to MaskedAutoencoderViT.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """ViT-Huge/14 encoder (1280-d, 32 blocks, 16 heads) with a 512-d, 8-block, 16-head MAE decoder.

    Extra keyword arguments are forwarded to MaskedAutoencoderViT.
    """
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
# Recommended architecture names, bound to the concrete constructors
# (every decoder here: 512 dim, 8 blocks).
(
    mae_vit_small_patch16,
    mae_vit_base_patch16,
    mae_vit_large_patch16,
    mae_vit_huge_patch14,
) = (
    mae_vit_small_patch16_dec512d8b,
    mae_vit_base_patch16_dec512d8b,
    mae_vit_large_patch16_dec512d8b,
    mae_vit_huge_patch14_dec512d8b,
)