|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
|
|
import torch |
|
from torch._C import Value |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
from timm.models.vision_transformer import PatchEmbed, Block |
|
from transformers import EncoderDecoderModel, BertTokenizer, AutoTokenizer |
|
|
|
|
|
from torch import einsum, nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class FocalLoss(nn.CrossEntropyLoss): |
|
''' Focal loss for classification tasks on imbalanced datasets ''' |
|
|
|
def __init__(self, gamma=1.0, alpha=None, ignore_index=-100, reduction='none'): |
|
super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none') |
|
self.reduction = reduction |
|
self.gamma = gamma |
|
|
|
def forward(self, input_, target): |
|
cross_entropy = super().forward(input_, target) |
|
|
|
|
|
|
|
target = target * (target != self.ignore_index).long() |
|
input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1)).squeeze(1) |
|
loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy |
|
return torch.mean(loss) if self.reduction == 'mean' \ |
|
else torch.sum(loss) if self.reduction == 'sum' \ |
|
else loss |
|
|
|
|
|
|
|
|
|
import math |
|
from functools import reduce |
|
|
|
def prob_mask_like(t, prob):
    """Return a bool tensor shaped like *t* that is True with probability *prob*.

    Each entry draws a uniform sample in [0, 1) and is True when the sample is
    below *prob*.
    """
    uniform_noise = torch.zeros_like(t, dtype=torch.float32).uniform_(0, 1)
    return uniform_noise < prob
|
|
|
def mask_with_tokens(t, token_ids):
    """Return a bool mask that is True wherever *t* equals any id in *token_ids*.

    An empty *token_ids* yields an all-False mask shaped like *t*.
    """
    hits = torch.zeros_like(t, dtype=torch.bool)
    for token_id in token_ids:
        hits = hits | (t == token_id)
    return hits
|
|
|
def get_mask_subset_with_prob(mask, prob):
    """Randomly select roughly ``prob`` of the True positions of ``mask`` in each row.

    Args:
        mask: boolean tensor; unpacked as ``(batch, seq_len)`` below, so it must be 2-D.
        prob: fraction of each row's True positions to select.

    Returns:
        A ``(batch, seq_len)`` bool tensor that is True at the selected positions.
    """
    batch, seq_len, device = *mask.shape, mask.device
    # Upper bound on the number of positions any row can select.
    max_masked = math.ceil(prob * seq_len)

    # Per-row count of True positions in the input mask.
    num_tokens = mask.sum(dim=-1, keepdim=True)
    # True where the running count of True positions exceeds the per-row budget
    # ceil(num_tokens * prob); truncated to max_masked columns so it lines up with
    # the max_masked sampled indices below (column k <-> k-th highest score).
    # NOTE(review): this compares a position-based running count with a rank over
    # sampled indices, which only lines up exactly when the True entries of
    # ``mask`` form a prefix of each row -- confirm this holds for callers.
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    # Random scores; positions outside ``mask`` get -1e9 so topk ranks them last.
    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    # Shift indices by one so column 0 can act as a dump slot for discarded picks.
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    # Scatter into seq_len + 1 columns, then drop the dump column 0.
    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
|
|
|
|
|
def exists(val):
    """Return ``True`` unless *val* is ``None`` (falsy values such as 0 count as existing)."""
    return not (val is None)
|
|
|
def default(val, d):
    """Return *val*, falling back to *d* only when *val* is ``None``."""
    if val is None:
        return d
    return val
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and a frozen zero bias.

    ``gamma`` is a trainable parameter; ``beta`` is a registered buffer, so it is
    saved in the state_dict but never updated by the optimiser.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)
|
|
|
|
|
class Residual(nn.Module):
    """Wrap *fn* with a skip connection: ``fn(x, ...) + x``.

    Extra positional and keyword arguments are forwarded to *fn* unchanged.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position angles for positions ``0 .. max_seq_len - 1``.

    Args:
        dim: head size; one inverse frequency ``10000 ** (-2i / dim)`` is stored
            per pair of channels in the ``inv_freq`` buffer.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        """Return angles ``position * inv_freq`` of shape
        ``(max_seq_len, 2 * len(inv_freq))``; the frequency block is repeated twice."""
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # Outer product position x frequency via broadcasting.
        angles = positions[:, None] * self.inv_freq[None, :]
        return torch.cat([angles, angles], dim=-1)
|
|
|
|
|
def rotate_half(x):
    """Split the last axis into two halves ``(a, b)`` and return ``cat(-b, a)``.

    The last axis length must be even (the einops pattern requires it).
    """
    halves = rearrange(x, "... (j d) -> ... j d", j=2)
    first, second = halves.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(pos, t):
    """Rotate *t* by the angles in *pos*: ``t * cos(pos) + rotate_half(t) * sin(pos)``."""
    cos_part = t * pos.cos()
    sin_part = rotate_half(t) * pos.sin()
    return cos_part + sin_part
|
|
|
|
|
|
|
|
|
class SwiGLU(nn.Module):
    """Gated activation: split the last axis into ``(value, gate)`` and return ``silu(gate) * value``.

    The output's last axis is half the size of the input's.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
|
|
|
|
|
|
|
|
|
class ParallelTransformerBlock(nn.Module):
    """Pre-norm block running causal self-attention and a SwiGLU feed-forward in parallel.

    One fused projection produces the queries (``heads`` heads), a single shared
    key head, a single shared value head (``dim_head`` each) and the
    feed-forward input. The attention output and the feed-forward output are summed.

    Args:
        dim: input/output feature size.
        dim_head: size of each attention head.
        heads: number of query heads.
        ff_mult: hidden-size multiplier of the feed-forward branch.
        attn_drop_rate: probability of dropping an attention logit while training.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, attn_drop_rate=0.0):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # Split sizes of the fused projection: (queries, key, value, FF input).
        # The FF input is doubled because SwiGLU halves it.
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        self.attn_drop_rate = attn_drop_rate

        # Lazily-built caches for the causal mask and rotary angles; non-persistent,
        # so they never enter the state_dict.
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Return an ``(n, n)`` bool mask that is True strictly above the diagonal (future keys)."""
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Return rotary angles for the first ``n`` positions, reusing the cached table when large enough."""
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        Performs self attention and feedforward.

        Args:
            x: token features; ``x.shape[1]`` is the sequence length.
            attn_mask: optional bool mask ``(b, i, j)``; False entries are blocked
                in addition to the causal mask.

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        x = self.norm(x)

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # Only queries are split into heads; k and v keep a single shared head.
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        q = q * self.scale

        sim = einsum("b h i d, b j d -> b h i j", q, k)

        mask_value = -torch.finfo(sim.dtype).max

        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, mask_value)

        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, mask_value)

        # Attention-logit dropout, only in training mode (eval must be deterministic).
        # torch.rand follows sim's device, so this also works on CPU.
        if self.training and self.attn_drop_rate > 0.:
            drop = torch.rand(sim.shape, device=sim.device) < self.attn_drop_rate
            # Never drop the diagonal: if every causal key of a row were dropped, all
            # its logits would equal mask_value and softmax would spread attention
            # uniformly over *future* positions as well.
            drop = drop & ~torch.eye(n, dtype=torch.bool, device=sim.device)
            sim = sim.masked_fill(drop, mask_value)

        # Subtract the row max for numerical stability (softmax is shift-invariant).
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
|
|
|
|
|
class CrossAttention(nn.Module):
    """Pre-norm cross-attention from ``x`` (queries) to ``context`` (keys/values).

    Queries use ``heads`` heads; keys and values use one shared head of size
    ``dim_head``. With ``parallel_ff=True`` a SwiGLU feed-forward of the
    normalised ``x`` is added to the attention output.

    Args:
        dim: feature size of ``x`` and of the output.
        context_dim: feature size of ``context`` (defaults to ``dim``).
        dim_head: size of each attention head.
        heads: number of query heads.
        parallel_ff: add the parallel feed-forward branch.
        ff_mult: hidden-size multiplier of the feed-forward branch.
        norm_context: also apply a LayerNorm to ``context``.
        dropout: probability of dropping an attention logit while training.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
        dropout=0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # A single key head and a single value head shared by all query heads.
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.dropout = dropout

        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Use text as query and image as kv.

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        x = self.norm(x)
        context = self.context_norm(context)

        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)

        q = q * self.scale

        k, v = self.to_kv(context).chunk(2, dim=-1)

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # Attention-logit dropout: dropped logits get the most negative finite value
        # so softmax gives them ~0 weight. Skipped in eval mode and when the rate is
        # 0; torch.rand follows sim's device, so this also works on CPU.
        if self.training and self.dropout > 0:
            drop = torch.rand(sim.shape, device=sim.device) < self.dropout
            sim = sim.masked_fill(drop, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if exists(self.ff):
            out = out + self.ff(x)
        return out
|
|
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2-D sine/cosine position table for a square grid.

    Args:
        embed_dim: embedding size (must be even; see the grid helper).
        grid_size: number of patches along each side of the square grid.
        cls_token: if True, prepend an all-zero row for a class token.

    Returns:
        ``[grid_size*grid_size, embed_dim]`` array, or
        ``[1+grid_size*grid_size, embed_dim]`` when ``cls_token`` is set.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # First plane varies along columns, second along rows.
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-plane coordinate grid with 1-D sin-cos embeddings.

    ``grid[0]`` and ``grid[1]`` are each encoded with ``embed_dim // 2`` channels
    and concatenated in that order. When called from
    ``get_2d_sincos_pos_embed``, ``grid[0]`` holds the first ``meshgrid`` output
    (the column coordinate).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    emb_h, emb_w = [get_1d_sincos_pos_embed_from_grid(half_dim, plane) for plane in (grid[0], grid[1])]
    return np.concatenate([emb_h, emb_w], axis=1)
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine/cosine embedding of a list of positions.

    Args:
        embed_dim: output dimension for each position (must be even).
        pos: array of positions; flattened to size (M,).

    Returns:
        (M, embed_dim) array: the first half holds ``sin(pos * omega)``, the second
        half ``cos(pos * omega)``, with ``omega_i = 1 / 10000 ** (2i / embed_dim)``.
    """
    assert embed_dim % 2 == 0
    frequencies = np.arange(embed_dim // 2, dtype=np.float32)
    frequencies /= embed_dim / 2.
    frequencies = 1. / 10000**frequencies

    flat_pos = pos.reshape(-1)
    # Outer product position x frequency via broadcasting.
    phases = flat_pos[:, None] * frequencies[None, :]

    return np.concatenate([np.sin(phases), np.cos(phases)], axis=1)
|
|
|
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone and a captioning head.

    Image path: patchify, randomly mask patches, encode the visible ones with ViT
    blocks and reconstruct the pixels of the masked ones with a lighter decoder.

    Text path: embed caption tokens, run them through causal ``unimodal_layers``,
    fuse them with the image latent in ``multimodal_layers`` and predict the next
    token with ``to_logits``.

    Many constructor options (focal loss, label smoothing, next-k prediction,
    masked text, ...) are only stored as attributes; ``forward`` in this class
    does not read them.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True,
                 unimodal_depth=2, multimodal_depth=8, dim_head=64, heads=8,
                 ff_mult=4, extract_multi_level=False, use_focal_loss=False, focal_gamma=1.0,
                 less_u=False, use_weak_negative=False, use_label_smooth=False, ls_coef=0.1,
                 use_maximum_entropy=False, ce_additional=False, use_word_weights=False, use_token_pos=False,
                 use_expect_k=False, use_top_k=False, mae_decoder_caption=False, decoder_slot_depth=2,
                 disable_decoder_vis_token_grad=False,
                 cross_attn_dropout=0.0, predict_next_k_words=False, next_k=3, masked_text=False,
                 masked_text_ratio=0.25, text_length=70,
                 projector_layer=0, uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4, text_drop_attn=0.):
        super().__init__()

        # ------------------------------------------------------------------
        # MAE encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) sin-cos table, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

        # ------------------------------------------------------------------
        # MAE decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Fixed (non-trainable) sin-cos table, filled in initialize_weights().
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)

        self.mae_decoder_depth = decoder_depth
        self.mae_decoder_caption = mae_decoder_caption
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        if self.mae_decoder_caption:
            self.decoder_slot_layers = nn.ModuleList([])
            for _ in range(decoder_slot_depth):
                self.decoder_slot_layers.append(
                    Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads,
                                            parallel_ff=True, ff_mult=ff_mult,)),
                )
            self.decoder_caption_proj = nn.Linear(decoder_embed_dim, embed_dim)
        # Plain flag; always set so the attribute exists regardless of mae_decoder_caption.
        self.disable_decoder_vis_token_grad = disable_decoder_vis_token_grad

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # Predicts the pixels of one patch (p * p * in_chans values) per token.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)

        self.norm_pix_loss = norm_pix_loss

        # ------------------------------------------------------------------
        # Text: tokenizer, embeddings and unimodal (text-only) layers
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", )

        self.token_emb = nn.Embedding(len(self.tokenizer.vocab), uni_dim)
        self.text_cls_token = nn.Parameter(torch.randn(uni_dim))

        self.embed_dim = embed_dim
        self.uni_dim = uni_dim

        self.text_drop_attn = text_drop_attn
        self.unimodal_layers = nn.ModuleList([])
        for _ in range(unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=uni_dim, dim_head=uni_dim_head,
                                                  heads=uni_heads, ff_mult=uni_ff_mult,
                                                  attn_drop_rate=self.text_drop_attn)),
            )

        # Project text features to the image width when the two differ.
        self.need_uni_2_mul_proj = False
        if uni_dim != embed_dim:
            self.need_uni_2_mul_proj = True
            self.uni_2_mul_proj = nn.Linear(uni_dim, embed_dim)

        # ------------------------------------------------------------------
        # Multimodal layers: each is a pair of residual sub-layers.
        self.multimodal_layers = nn.ModuleList([])
        self.less_u = less_u
        if less_u:
            # Two cross-attentions per layer (no text self-attention).
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True,
                                            ff_mult=ff_mult, dropout=cross_attn_dropout)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True,
                                            ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))
        else:
            # Causal text self-attention followed by cross-attention to the image.
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(ParallelTransformerBlock(dim=embed_dim, dim_head=dim_head, heads=heads,
                                                      ff_mult=ff_mult)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True,
                                            ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))

        self.to_logits = nn.Sequential(
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
        )

        self.ce_additional = ce_additional
        if ce_additional:
            self.to_logits_1 = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
            )

        nn.init.normal_(self.token_emb.weight, std=0.02)

        # Special token ids (pad_id is also the ignore_index of the caption loss).
        self.pad_id = 0
        self.cls_id = 101
        self.sep_id = 102

        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.extract_multi_level = extract_multi_level
        if self.extract_multi_level:
            # One projector per tapped encoder block (indices 2, 5 and 8).
            self.projectors = nn.ModuleList([nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.GELU(),
                nn.Linear(embed_dim // 2, embed_dim),
                norm_layer(embed_dim)
            ) for _ in [2, 5, 8, ]])

        self.use_focal_loss = use_focal_loss

        self.use_weak_negative = use_weak_negative
        self.use_label_smooth = use_label_smooth
        self.ls_coef = ls_coef
        self.use_entropy = use_maximum_entropy
        self.use_word_weights = use_word_weights
        self.use_token_pos = use_token_pos

        self.predict_next_k_words = predict_next_k_words
        self.next_k = next_k
        self.pad = torch.nn.ReplicationPad1d((0, self.next_k - 1))

        self.use_expect_k = use_expect_k
        self.use_top_k = use_top_k

        # Per-token losses are kept ('none') when they are to be re-weighted.
        if self.use_word_weights or self.use_token_pos:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='none')
        else:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='mean')

        self.masked_text = masked_text
        self.masked_text_ratio = masked_text_ratio

        # NOTE(review): this id equals the embedding table size, so it is out of
        # range for self.token_emb if it is ever embedded -- confirm before use.
        self.mask_token_id = len(self.tokenizer.vocab)

        self.text_length = text_length

        self.latent_projector_layer = projector_layer
        if self.latent_projector_layer != 0:
            # Build fresh modules for every hidden layer. (Multiplying a list of
            # modules would repeat the *same* objects, making all hidden layers
            # share one set of weights.) Indices match the old layout, so
            # existing checkpoints load unchanged.
            projector_modules = []
            for _ in range(self.latent_projector_layer - 1):
                projector_modules.append(nn.Linear(embed_dim, embed_dim))
                projector_modules.append(nn.ReLU())
            projector_modules.append(nn.Linear(embed_dim, embed_dim))
            self.latent_projector = nn.Sequential(*projector_modules)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed position tables and initialise the remaining weights."""
        # Fixed 2-D sin-cos position tables (with a zero row for the cls token).
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Initialise the patch-embedding projection like an nn.Linear
        # (xavier on the flattened weight).
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # Linear / nn.LayerNorm initialisation for all submodules.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero bias for nn.Linear; unit weight / zero bias for nn.LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, C, H, W) with H == W divisible by the patch size
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        # Channel count taken from the input, so in_chans != 3 also works.
        c = imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        # Channel count recovered from the per-patch feature size.
        c = x.shape[2] // (p * p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns:
            x_masked: [N, int(L * (1 - mask_ratio)), D], the kept tokens
            mask: [N, L] in original order, 0 is keep, 1 is remove
            ids_restore: [N, L], indices that undo the shuffle
            ids_keep: [N, len_keep], original indices of the kept tokens
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)

        # Ascending sort: the lowest-noise tokens are kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Binary mask in shuffled order, then unshuffled back to the original order.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, x, mask_ratio):
        """Embed patches, randomly mask them, prepend the cls token and run the encoder.

        Returns ``(latent, mask, ids_restore, ids_keep)``. With
        ``extract_multi_level`` it instead returns ``(multi_level_feats, mask,
        ids_restore)``, where the list holds projected features after blocks
        2, 5 and 8 followed by the final normalised output.
        """
        x = self.patch_embed(x)

        # Position embedding without the cls slot.
        x = x + self.pos_embed[:, 1:, :]

        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.extract_multi_level:
            multi_level_feats = []
            for blk_idx, blk in enumerate(self.blocks):
                x = blk(x)
                if blk_idx in [2, 5, 8]:
                    multi_level_feats.append(self.projectors[[2, 5, 8].index(blk_idx)](x))
            x = self.norm(x)
            multi_level_feats.append(x)
            # NOTE(review): this branch returns 3 values while forward() unpacks 4.
            return multi_level_feats, mask, ids_restore

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, ids_keep

    def forward_decoder(self, x, ids_restore):
        """Restore the full token order with mask tokens and predict patch pixels.

        Returns:
            pred: [N, L, p*p*C] pixel predictions (cls position removed).
            decoder_feat: list with the hidden state after decoder block
                ``mae_decoder_depth // 2`` (empty if that index is not reached).
        """
        x = self.decoder_embed(x)

        # Append mask tokens for the removed patches, then unshuffle (cls excluded).
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        decoder_feat = []
        for idx, blk in enumerate(self.decoder_blocks):
            x = blk(x)
            if idx == self.mae_decoder_depth // 2:
                decoder_feat.append(x)

        x = self.decoder_norm(x)

        x = self.decoder_pred(x)

        # Drop the cls position.
        x = x[:, 1:, :]

        return x, decoder_feat

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove,

        Returns the mean per-patch MSE over removed patches. With
        ``norm_pix_loss`` each target patch is standardised by its own mean/variance.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def embed_text(self, text):
        """Embed token ids and run the unimodal text layers.

        A learned text cls token is appended for the unimodal pass and dropped
        again before returning. Returns features for the ``text`` positions only
        (projected to ``embed_dim`` when ``uni_dim`` differs).
        """
        batch, device = text.shape[0], text.device

        seq = text.shape[1]

        text_tokens = self.token_emb(text)

        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # (b, seq+1, seq+1) mask: text rows are all True (only the causal mask in
        # the block applies); the appended cls row attends to non-pad tokens and itself.
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        if self.need_uni_2_mul_proj:
            text_tokens = self.uni_2_mul_proj(text_tokens)

        text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
        return text_tokens

    def forward(self, imgs, caption_ids=None, attention_mask=None, mask_ratio=0.75,
                freeze_bert=False, teacher_forcing=False, caption_only=False,
                encoder_only=False, word_weights=None, syn_count=None):
        """Compute the MAE reconstruction loss and the next-token captioning loss.

        Args:
            imgs: image batch.
            caption_ids: token ids; tokens ``[:-1]`` are inputs and ``[1:]`` are
                labels. Required in practice despite the ``None`` default.
            mask_ratio: fraction of patches hidden from the encoder.
            caption_only: skip the MAE decoder and report ``mae_loss = 0.``.
            attention_mask, freeze_bert, teacher_forcing, encoder_only,
            word_weights, syn_count: accepted but not used here.

        Returns:
            ``(mae_loss, caption_loss, None)``; pad tokens are ignored in the caption loss.
        """
        latent, mask, ids_restore, ids_keep = self.forward_encoder(imgs, mask_ratio)

        if not caption_only:
            pred, decoder_feat = self.forward_decoder(latent, ids_restore)
            mae_loss = self.forward_loss(imgs, pred, mask)
        else:
            mae_loss = 0.

        if self.latent_projector_layer != 0:
            latent = self.latent_projector(latent)

        # Teacher forcing: predict token t+1 from tokens <= t.
        text, labels = caption_ids[:, :-1], caption_ids[:, 1:]

        seq = text.shape[1]
        text_tokens = self.embed_text(text)

        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
        unimodal_text_tokens = text_tokens
        if not self.less_u:
            for attn_ff, cross_attn in self.multimodal_layers:
                # Drop the cls row/column: the multimodal stream has no cls token.
                text_tokens = attn_ff(text_tokens, attn_mask=attn_mask[:, :-1, :-1])
                text_tokens = cross_attn(text_tokens, latent)
        else:
            for cross_attn1, cross_attn2 in self.multimodal_layers:
                text_tokens = cross_attn1(text_tokens, latent)
                text_tokens = cross_attn2(text_tokens, latent)

        logits = self.to_logits(text_tokens)
        logits = logits.reshape(-1, len(self.tokenizer.vocab))
        labels = labels.reshape(-1)

        caption_loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id,)

        return mae_loss, caption_loss, None
|
|
|
|
|
|
|
def mae_vit_small_patch16_dec512d8b(**kwargs):
    """ViT-Small/16 encoder (384-d, 12 blocks, 6 heads) with a 512-d, 8-block, 16-head decoder.

    Extra keyword arguments go to ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
|
|
|
|
|
|
|
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """ViT-Base/16 encoder (768-d, 12 blocks, 12 heads) with a 512-d, 8-block, 16-head decoder.

    Extra keyword arguments go to ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
|
|
|
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """ViT-Large/16 encoder (1024-d, 24 blocks, 16 heads) with a 512-d, 8-block, 16-head decoder.

    Extra keyword arguments go to ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
|
|
|
|
|
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """ViT-Huge/14 encoder (1280-d, 32 blocks, 16 heads) with a 512-d, 8-block, 16-head decoder.

    Extra keyword arguments go to ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
|
|
|
|
|
|
|
# Short aliases for the model factories (each uses a 512-d, 8-block decoder).
(
    mae_vit_small_patch16,
    mae_vit_base_patch16,
    mae_vit_large_patch16,
    mae_vit_huge_patch14,
) = (
    mae_vit_small_patch16_dec512d8b,
    mae_vit_base_patch16_dec512d8b,
    mae_vit_large_patch16_dec512d8b,
    mae_vit_huge_patch14_dec512d8b,
)
|
|
|
|
|
|
|
|
|
|
|
|