Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
# -------------------------------------------------------- | |
# References: | |
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm | |
# DeiT: https://github.com/facebookresearch/deit | |
# -------------------------------------------------------- | |
from functools import partial | |
import torch | |
from torch._C import Value | |
import torch.nn as nn | |
import numpy as np | |
from timm.models.vision_transformer import PatchEmbed, Block | |
from transformers import EncoderDecoderModel, BertTokenizer, AutoTokenizer | |
from torch import einsum, nn | |
import torch.nn.functional as F | |
from einops import rearrange, repeat | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class FocalLoss(nn.CrossEntropyLoss): | |
''' Focal loss for classification tasks on imbalanced datasets ''' | |
def __init__(self, gamma=1.0, alpha=None, ignore_index=-100, reduction='none'): | |
super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none') | |
self.reduction = reduction | |
self.gamma = gamma | |
def forward(self, input_, target): | |
cross_entropy = super().forward(input_, target) | |
# Temporarily mask out ignore index to '0' for valid gather-indices input. | |
# This won't contribute final loss as the cross_entropy contribution | |
# for these would be zero. | |
target = target * (target != self.ignore_index).long() | |
input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1)).squeeze(1) | |
loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy | |
return torch.mean(loss) if self.reduction == 'mean' \ | |
else torch.sum(loss) if self.reduction == 'sum' \ | |
else loss | |
# helper functions | |
import math | |
from functools import reduce | |
def prob_mask_like(t, prob):
    """Return a bool tensor shaped like ``t`` whose entries are True with probability ``prob``."""
    noise = torch.zeros_like(t).float()
    noise.uniform_(0, 1)
    return noise < prob
def mask_with_tokens(t, token_ids):
    """Return a bool mask, shaped like ``t``, that is True wherever ``t`` equals any id in ``token_ids``."""
    hit = torch.zeros_like(t, dtype=torch.bool)
    for token_id in token_ids:
        hit = hit | (t == token_id)
    return hit
def get_mask_subset_with_prob(mask, prob):
    """Randomly select ``ceil(prob * n_eligible)`` eligible positions in each row.

    Args:
        mask: bool tensor of shape (batch, seq_len); True marks eligible positions.
        prob: fraction of each row's eligible positions to select (0 <= prob <= 1).

    Returns:
        bool tensor of shape (batch, seq_len), True only at selected positions and
        always a subset of ``mask``.
    """
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    # per-row quota of positions to select
    num_tokens = mask.sum(dim=-1, keepdim=True)
    num_to_select = (num_tokens * prob).ceil()

    # ineligible positions get a very low score, so top-k ranks all eligible ones first
    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)

    # Drop picks beyond each row's quota by *rank* in the top-k list. (A cumsum over
    # sequence positions mis-counts whenever eligible positions do not start at
    # index 0, over-selecting and even picking ineligible positions.)
    ranks = torch.arange(max_masked, device=device)
    mask_excess = ranks >= num_to_select

    # shift by one so index 0 is a sink for discarded picks, then drop that column
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
def exists(val):
    """Tell whether ``val`` was provided, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``d`` only when ``val`` is ``None``; falsy values such as ``0`` are kept."""
    if val is None:
        return d
    return val
# normalization
# LayerNorm with a learnable scale and a fixed zero bias. Older PyTorch releases'
# nn.LayerNorm cannot disable only the bias (the `bias` argument arrived in PyTorch 2.1).
class LayerNorm(nn.Module):
    """Layer normalisation over the last dim with a learnable scale and a frozen zero shift."""

    def __init__(self, dim):
        super().__init__()
        # learnable per-feature scale, initialised to 1
        self.gamma = nn.Parameter(torch.ones(dim))
        # the shift is a buffer: saved and moved with the module, but never trained
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)
# residual | |
class Residual(nn.Module):
    """Wrap ``fn`` so its output is added back onto its first input: ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
# rotary positional embedding | |
# https://arxiv.org/abs/2104.09864 | |
class RotaryEmbedding(nn.Module):
    """Rotary positional-embedding angles (RoFormer, https://arxiv.org/abs/2104.09864).

    ``forward(n, device=...)`` returns the angles ``position * inv_freq`` for positions
    ``0..n-1``, with the frequency block repeated twice along the last dim, giving
    shape (n, dim) for an even ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product: angles[i, j] = positions[i] * inv_freq[j]
        angles = positions[:, None] * self.inv_freq[None, :]
        return torch.cat([angles, angles], dim=-1)
def rotate_half(x):
    """Split the last dim into two halves ``(a, b)`` and return ``concat(-b, a)``."""
    # view the last dim as (2, half); reshape fails for an odd size, as the pattern split does
    halves = x.reshape(*x.shape[:-1], 2, x.shape[-1] // 2)
    first, second = halves.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles ``pos``: ``t * cos(pos) + rotate_half(t) * sin(pos)``."""
    cos, sin = pos.cos(), pos.sin()
    return (t * cos) + (rotate_half(t) * sin)
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GELU for gating the feedforward | |
# https://arxiv.org/abs/2002.05202 | |
class SwiGLU(nn.Module):
    """SwiGLU gate: split the last dim into ``(value, gate)`` and return ``silu(gate) * value``."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        activated_gate = F.silu(gate)
        return activated_gate * value
# parallel attention and feedforward with residual | |
# discovered by Wang et al + EleutherAI from GPT-J fame | |
class ParallelTransformerBlock(nn.Module):
    """Pre-norm block running causal multi-query attention and a SwiGLU feedforward in parallel.

    Both branches read the same normalized input and their outputs are summed; the
    residual connection is added by the caller (e.g. ``Residual``).

    Args:
        dim: model width.
        dim_head: width of each query head and of the single shared key/value head.
        heads: number of query heads.
        ff_mult: feedforward hidden width as a multiple of ``dim``.
        attn_drop_rate: probability of masking out each attention logit, applied
            only in training mode.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, attn_drop_rate=0.0):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # one fused projection yields q (all heads), k and v (one shared head each),
        # and the doubled feedforward input consumed by SwiGLU
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )
        self.attn_drop_rate = attn_drop_rate

        # caches for the causal mask and rotary embeddings (non-persistent: not in state_dict)
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Return an (n, n) bool mask, True strictly above the diagonal, caching the largest one built."""
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Return rotary angles for ``n`` positions, caching the largest table built."""
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]
        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        Performs self attention and feedforward.

        Args:
            x: (b, n, dim) input sequence.
            attn_mask: optional (b, n, n) bool mask; True where attention is allowed.

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads; keys/values stay single-headed (multi-query attention,
        # https://arxiv.org/abs/1911.02150)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q = apply_rotary_pos_emb(positions, q)
        k = apply_rotary_pos_emb(positions, k)

        # scale and similarity
        q = q * self.scale
        sim = einsum("b h i d, b j d -> b h i j", q, k)
        mask_value = -torch.finfo(sim.dtype).max

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, mask_value)

        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, mask_value)

        # Attention dropout on the logits, in training mode only. Re-masking an already
        # masked logit is a no-op, so the random mask can cover every entry. rand_like
        # keeps this on sim's device (the former torch.cuda.FloatTensor broke on CPU).
        if self.training and self.attn_drop_rate > 0.:
            keep = torch.rand_like(sim) > self.attn_drop_rate
            sim = sim.masked_fill(~keep, mask_value)

        # attention (subtracting the row max is a numerical-stability shift)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward | |
class CrossAttention(nn.Module):
    """Multi-query cross attention from ``x`` (queries) to ``context`` (single key/value head).

    Args:
        dim: width of ``x``.
        context_dim: width of ``context``; defaults to ``dim``.
        dim_head: width of each query head and of the shared key/value head.
        heads: number of query heads.
        parallel_ff: if True, add a SwiGLU feedforward of the normalized ``x`` to the output.
        ff_mult: feedforward hidden width as a multiple of ``dim``.
        norm_context: if True, apply LayerNorm to ``context`` as well.
        dropout: probability of masking out each attention logit, applied only in
            training mode.

    The residual connection is not included; callers wrap this in ``Residual``.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
        dropout=0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = dropout

        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Use text as queries and the image as keys/values.

        Args:
            x: (b, i, dim) query sequence.
            context: (b, j, context_dim) key/value sequence.

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries, split into heads and scale
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        q = q * self.scale

        # single-headed key / values shared by all query heads
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # Attention dropout on the logits, only when training with a positive rate.
        # rand_like keeps this on sim's device (the former torch.cuda.FloatTensor
        # broke on CPU and ran even with dropout=0).
        if self.training and self.dropout > 0.:
            keep = torch.rand_like(sim) > self.dropout
            sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)

        # attention; the row-max shift is for stability only and needs no gradient
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate, then merge and combine heads
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)

        return out
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build fixed 2-D sine-cosine positional embeddings for a square patch grid.

    Args:
        embed_dim: embedding width (validated by get_2d_sincos_pos_embed_from_grid).
        grid_size: number of patches along both the height and the width.
        cls_token: if True, prepend an all-zero row for the class token.

    Returns:
        np.ndarray of shape (grid_size*grid_size, embed_dim), or
        (1 + grid_size*grid_size, embed_dim) when ``cls_token`` is True.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h) with 'xy' indexing: entry 0 holds the width coordinate, entry 1 the height
    grid = np.stack(np.meshgrid(coords, coords), axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-D coordinate grid with sine-cosine embeddings.

    The first half of the channels encodes ``grid[0]`` and the second half ``grid[1]``.
    The 1-D encoder further requires ``embed_dim // 2`` to be even.

    Args:
        embed_dim: total embedding width; must be even.
        grid: array whose first axis indexes the two coordinates; each ``grid[i]``
            is flattened into H*W positions.

    Returns:
        np.ndarray of shape (H*W, embed_dim).

    Raises:
        ValueError: if ``embed_dim`` is odd (an explicit check, unlike an assert,
            survives ``python -O``).
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")
    half = embed_dim // 2
    emb_first = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_second = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_first, emb_second], axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode 1-D positions with fixed sine-cosine embeddings.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: array-like of positions to encode; flattened to size (M,).

    Returns:
        np.ndarray of shape (M, embed_dim): sine terms in the first half of the
        channels, cosine terms in the second half.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")
    # np.float (an alias of the builtin float, i.e. float64) was removed in NumPy 1.24
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with a VisionTransformer backbone plus a text-captioning head.

    Image branch: random patch masking, a ViT encoder over the visible patches and a
    lightweight decoder reconstructing pixels (``forward_loss``). Caption branch: text
    ids are embedded by unimodal ``ParallelTransformerBlock`` layers, attend to the
    encoder latent through multimodal layers, and predict the next token with a
    softmax over the tokenizer vocabulary.

    NOTE(review): several constructor flags (e.g. ``use_focal_loss``,
    ``use_weak_negative``, ``use_label_smooth``, ``masked_text``) are stored on the
    module but not consumed by ``forward`` here; confirm whether other code uses them.
    """

    # encoder block indices whose outputs are projected when extract_multi_level is on
    _MULTI_LEVEL_BLOCKS = (2, 5, 8)

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True,
                 unimodal_depth=2, multimodal_depth=8, dim_head=64, heads=8,
                 ff_mult=4, extract_multi_level=False, use_focal_loss=False, focal_gamma=1.0,
                 less_u=False, use_weak_negative=False, use_label_smooth=False, ls_coef=0.1,
                 use_maximum_entropy=False, ce_additional=False, use_word_weights=False,
                 use_token_pos=False, use_expect_k=False, use_top_k=False,
                 mae_decoder_caption=False, decoder_slot_depth=2,
                 disable_decoder_vis_token_grad=False, cross_attn_dropout=0.0,
                 predict_next_k_words=False, next_k=3, masked_text=False,
                 masked_text_ratio=0.25, text_length=70, projector_layer=0, uni_dim=1024,
                 uni_dim_head=64, uni_heads=8, uni_ff_mult=4, text_drop_attn=0.):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # fixed sin-cos embedding, filled in by initialize_weights()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # fixed sin-cos embedding, filled in by initialize_weights()
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)

        self.mae_decoder_depth = decoder_depth
        self.mae_decoder_caption = mae_decoder_caption
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        if self.mae_decoder_caption:
            self.decoder_slot_layers = nn.ModuleList([
                Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads,
                                        parallel_ff=True, ff_mult=ff_mult))
                for _ in range(decoder_slot_depth)])
            self.decoder_caption_proj = nn.Linear(decoder_embed_dim, embed_dim)
        self.disable_decoder_vis_token_grad = disable_decoder_vis_token_grad

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # decoder features -> pixel values of one patch
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
        self.norm_pix_loss = norm_pix_loss

        # --------------------------------------------------------------------------
        # captioner specifics
        # unimodal layers process text tokens; multimodal layers let text query the image.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", )
        vocab_size = len(self.tokenizer.vocab)

        # token embeddings (no extra row for an MLM mask token; see mask_token_id below)
        self.token_emb = nn.Embedding(vocab_size, uni_dim)
        self.text_cls_token = nn.Parameter(torch.randn(uni_dim))
        self.embed_dim = embed_dim
        self.uni_dim = uni_dim

        # unimodal layers
        # TODO: search on uni_dim, uni_dim_head, uni_heads, uni_ff_mult
        self.text_drop_attn = text_drop_attn
        self.unimodal_layers = nn.ModuleList([
            Residual(ParallelTransformerBlock(dim=uni_dim, dim_head=uni_dim_head, heads=uni_heads,
                                              ff_mult=uni_ff_mult, attn_drop_rate=self.text_drop_attn))
            for _ in range(unimodal_depth)])

        self.need_uni_2_mul_proj = uni_dim != embed_dim
        if self.need_uni_2_mul_proj:
            self.uni_2_mul_proj = nn.Linear(uni_dim, embed_dim)

        # multimodal layers: (self-attention or cross-attention, cross-attention) pairs
        self.multimodal_layers = nn.ModuleList([])
        self.less_u = less_u
        for _ in range(multimodal_depth):
            if less_u:
                first = Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads,
                                                parallel_ff=True, ff_mult=ff_mult,
                                                dropout=cross_attn_dropout))
            else:
                first = Residual(ParallelTransformerBlock(dim=embed_dim, dim_head=dim_head,
                                                          heads=heads, ff_mult=ff_mult))
            second = Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads,
                                             parallel_ff=True, ff_mult=ff_mult,
                                             dropout=cross_attn_dropout))
            self.multimodal_layers.append(nn.ModuleList([first, second]))

        # to logits: for softmax caption loss
        self.to_logits = nn.Sequential(
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, vocab_size, bias=False)
        )
        self.ce_additional = ce_additional
        if ce_additional:
            # to logits: for other losses
            self.to_logits_1 = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, vocab_size, bias=False)
            )

        nn.init.normal_(self.token_emb.weight, std=0.02)
        # presumably the [PAD]/[CLS]/[SEP] ids of bert-base-uncased — verify if the tokenizer changes
        self.pad_id = 0
        self.cls_id = 101
        self.sep_id = 102
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.extract_multi_level = extract_multi_level
        if self.extract_multi_level:
            self.projectors = nn.ModuleList([nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.GELU(),
                nn.Linear(embed_dim // 2, embed_dim),
                norm_layer(embed_dim)
            ) for _ in self._MULTI_LEVEL_BLOCKS])

        # --------------------------------------------------------------------------
        self.use_focal_loss = use_focal_loss
        self.use_weak_negative = use_weak_negative
        self.use_label_smooth = use_label_smooth
        self.ls_coef = ls_coef
        self.use_entropy = use_maximum_entropy
        self.use_word_weights = use_word_weights
        self.use_token_pos = use_token_pos
        self.predict_next_k_words = predict_next_k_words
        self.next_k = next_k
        self.pad = torch.nn.ReplicationPad1d((0, self.next_k - 1))
        self.use_expect_k = use_expect_k
        self.use_top_k = use_top_k
        # per-token losses are kept unreduced when they will be re-weighted
        focal_reduction = 'none' if (self.use_word_weights or self.use_token_pos) else 'mean'
        self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction=focal_reduction)

        self.masked_text = masked_text
        self.masked_text_ratio = masked_text_ratio
        # NOTE(review): equals len(vocab), but token_emb has only len(vocab) rows, so this
        # id cannot be embedded as-is; the "+1 for mask token" embedding was disabled.
        self.mask_token_id = vocab_size
        self.text_length = text_length

        self.latent_projector_layer = projector_layer
        if self.latent_projector_layer != 0:
            self.latent_projector = self._build_latent_projector(embed_dim, self.latent_projector_layer)

        self.initialize_weights()

    @staticmethod
    def _build_latent_projector(dim, num_layers):
        """Return an MLP of ``num_layers`` Linear(dim, dim) layers with ReLU in between."""
        layers = []
        for _ in range(num_layers - 1):
            # fresh modules per layer; list repetition would reuse (and so tie) one Linear
            layers.extend([nn.Linear(dim, dim), nn.ReLU()])
        layers.append(nn.Linear(dim, dim))
        return nn.Sequential(*layers)

    def initialize_weights(self):
        """Fill the fixed sin-cos position embeddings and initialise learnable weights."""
        grid_size = int(self.patch_embed.num_patches ** .5)
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # affine-less LayerNorms have no weight/bias to initialise
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, C, H, W) with H == W and H divisible by the patch size
        x: (N, L, patch_size**2 * C), channels last within each patch
        """
        p = self.patch_embed.patch_size[0]
        n, c, height, width = imgs.shape
        assert height == width and height % p == 0
        h = w = height // p
        x = imgs.reshape(n, c, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(n, h * w, p**2 * c)

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W); C is inferred from the last dim of x
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        c = x.shape[2] // (p * p)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        Args:
            x: [N, L, D], sequence
            mask_ratio: fraction of the L positions to drop

        Returns:
            x_masked: [N, len_keep, D] kept tokens
            mask: [N, L] binary mask, 0 is keep, 1 is remove (original order)
            ids_restore: [N, L] permutation that undoes the shuffle
            ids_keep: [N, len_keep] indices of the kept tokens
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, x, mask_ratio):
        """Embed, randomly mask and encode images.

        Returns:
            (latent, mask, ids_restore, ids_keep). ``latent`` is the normed encoder
            output (CLS token first). With ``extract_multi_level`` it is instead a
            list of projected features after blocks 2, 5 and 8 followed by the normed
            final output; the 4-tuple shape is the same in both modes.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        multi_level_feats = []
        for blk_idx, blk in enumerate(self.blocks):
            x = blk(x)
            if self.extract_multi_level and blk_idx in self._MULTI_LEVEL_BLOCKS:
                projector = self.projectors[self._MULTI_LEVEL_BLOCKS.index(blk_idx)]
                multi_level_feats.append(projector(x))
        x = self.norm(x)

        if self.extract_multi_level:
            multi_level_feats.append(x)
            # same arity as the single-level path so forward() can unpack it
            return multi_level_feats, mask, ids_restore, ids_keep
        return x, mask, ids_restore, ids_keep

    def forward_decoder(self, x, ids_restore):
        """Reconstruct all patches from the encoded visible tokens.

        Args:
            x: encoder output (N, 1 + len_keep, embed_dim), CLS token first.
            ids_restore: (N, L) permutation returned by ``random_masking``.

        Returns:
            (pred, decoder_feat): ``pred`` is (N, L, patch_size**2 * in_chans) without
            the CLS position; ``decoder_feat`` is a list holding the hidden state after
            decoder block ``mae_decoder_depth // 2`` (empty if there is no such block).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens so the sequence covers all L patches again
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        decoder_feat = []
        for idx, blk in enumerate(self.decoder_blocks):
            x = blk(x)
            if idx == self.mae_decoder_depth // 2:
                decoder_feat.append(x)
        x = self.decoder_norm(x)

        # predictor projection, then remove cls token
        x = self.decoder_pred(x)
        return x[:, 1:, :], decoder_feat

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove

        Returns the mean squared error over removed patches only; with
        ``norm_pix_loss`` each target patch is first normalised by its own mean/var.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        return (loss * mask).sum() / mask.sum()  # mean loss on removed patches

    def embed_text(self, text):
        """Embed token ids ``text`` (N, seq) with the unimodal layers.

        A learnable text-CLS token is appended so it can attend over the non-padding
        tokens; it is dropped again before returning.

        Returns:
            (N, seq, embed_dim) token features (projected from ``uni_dim`` if needed).
        """
        batch, seq = text.shape[0], text.shape[1]

        text_tokens = self.token_emb(text)

        # append text cls tokens
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # specific mask for the text cls token at the end,
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        # go through unimodal layers
        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        if self.need_uni_2_mul_proj:
            text_tokens = self.uni_2_mul_proj(text_tokens)

        # drop the trailing text-cls token; only per-token features are returned
        return text_tokens[:, :-1]

    def forward(self, imgs, caption_ids=None, attention_mask=None, mask_ratio=0.75,
                freeze_bert=False, teacher_forcing=False, caption_only=False,
                encoder_only=False, word_weights=None, syn_count=None):
        """Compute the MAE reconstruction loss and the next-token caption loss.

        Args:
            imgs: (N, C, H, W) images.
            caption_ids: (N, Len) token ids; position t is trained to predict t + 1.
            mask_ratio: fraction of patches hidden from the encoder.
            caption_only: skip the pixel decoder and report ``mae_loss = 0.``.
            attention_mask, freeze_bert, teacher_forcing, encoder_only, word_weights,
            syn_count: accepted for API compatibility; not used here.

        Returns:
            (mae_loss, caption_loss, None)

        Raises:
            ValueError: if ``caption_ids`` is None.
        """
        if caption_ids is None:
            raise ValueError("caption_ids is required to compute the caption loss")

        latent, mask, ids_restore, ids_keep = self.forward_encoder(imgs, mask_ratio)

        if caption_only:
            mae_loss = 0.
        else:
            pred, _decoder_feat = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*C]
            mae_loss = self.forward_loss(imgs, pred, mask)

        if self.latent_projector_layer != 0:
            latent = self.latent_projector(latent)

        # latent: visual info (N, L, C); text predicts the shifted-by-one labels
        text, labels = caption_ids[:, :-1], caption_ids[:, 1:]
        seq = text.shape[1]
        text_tokens = self.embed_text(text)  # N, Len, C

        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        if self.less_u:
            for cross_attn1, cross_attn2 in self.multimodal_layers:
                text_tokens = cross_attn1(text_tokens, latent)
                text_tokens = cross_attn2(text_tokens, latent)
        else:
            for attn_ff, cross_attn in self.multimodal_layers:
                # NOTE: slicing off the CLS row/column leaves an all-True mask, so this
                # self-attention is restricted only by the block's causal mask
                text_tokens = attn_ff(text_tokens, attn_mask=attn_mask[:, :-1, :-1])
                text_tokens = cross_attn(text_tokens, latent)

        vocab_size = len(self.tokenizer.vocab)
        logits = self.to_logits(text_tokens).reshape(-1, vocab_size)  # (N*Len, NVocab)
        labels = labels.reshape(-1)
        caption_loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id)
        return mae_loss, caption_loss, None
def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build a small/16 MAE (384-dim, 12 blocks, 6 heads) with a 512-dim, 8-block decoder.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build a base/16 MAE (768-dim, 12 blocks, 12 heads) with a 512-dim, 8-block decoder.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build a large/16 MAE (1024-dim, 24 blocks, 16 heads) with a 512-dim, 8-block decoder.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build a huge/14 MAE (1280-dim, 32 blocks, 16 heads) with a 512-dim, 8-block decoder.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
# Recommended architecture aliases; every decoder is 512-dim with 8 blocks.
(
    mae_vit_small_patch16,
    mae_vit_base_patch16,
    mae_vit_large_patch16,
    mae_vit_huge_patch14,
) = (
    mae_vit_small_patch16_dec512d8b,
    mae_vit_base_patch16_dec512d8b,
    mae_vit_large_patch16_dec512d8b,
    mae_vit_huge_patch14_dec512d8b,
)