| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp |
| |
|
| | from util.pos_embed import get_2d_sincos_pos_embed |
| |
|
| | from taming.models.vqgan import VQModel |
| | from omegaconf import OmegaConf |
| | import numpy as np |
| | import scipy.stats as stats |
| | from compressai.entropy_models import EntropyBottleneck |
| | from compressai.layers import conv3x3, subpel_conv3x3 |
| | import math |
| | from torch import Tensor |
| | from einops import rearrange, repeat |
| | import torch.nn.functional as F |
| | import torchac |
| | from typing import Any, Callable, List, Optional, Tuple, Union |
| |
|
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return a geometric progression of `levels` scales from `min` to `max`.

    Matches the standard compressai scale table used to index Gaussian
    conditional entropy models.
    """
    log_lo = math.log(min)
    log_hi = math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()
| |
|
def ste_round(x: Tensor) -> Tensor:
    """Round `x` with a straight-through gradient estimator.

    Forward pass returns round(x); backward pass passes the gradient
    through unchanged (gradient of 1), since the rounding residual is
    detached from the graph.
    """
    return x + (torch.round(x) - x).detach()
| |
|
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a strided Conv2d with 'same'-style padding (kernel_size // 2)."""
    pad = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=pad,
    )
| |
|
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Mask the `mask_len` lowest-confidence positions per row.

    Confidence is log(probs) perturbed by temperature-scaled Gumbel noise,
    so positions are re-masked stochastically (MaskGIT-style sampling).

    Fixes vs. original: noise is created on ``probs.device`` instead of a
    hard-coded ``.cuda()`` (which crashed on CPU and multi-GPU setups), and
    ``torch.sort`` is called with the supported ``dim=`` keyword instead of
    the numpy-style ``axis=``.

    Parameters:
    - mask_len: scalar-like tensor, number of positions to mask per row.
    - probs: (B, L) confidence probabilities of the sampled tokens.
    - temperature: Gumbel noise scale; 0 makes the choice deterministic.

    Returns:
    - BoolTensor (B, L): True where the position should be (re-)masked.
    """
    mask_len = mask_len.squeeze()
    # Gumbel-perturbed log-confidence, on the same device/dtype as probs.
    gumbel = torch.tensor(
        np.random.gumbel(size=probs.shape), dtype=probs.dtype, device=probs.device
    )
    confidence = torch.log(probs) + temperature * gumbel
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Per-row cut-off: the mask_len-th smallest confidence value.
    cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
    masking = (confidence <= cut_off)
    return masking
| |
|
def adjust_mask_and_drop_embeddings(token_keep_mask):
    """Shrink the kept-token count to the largest perfect square, in place.

    Randomly flips just enough True entries to False so that the number of
    kept tokens is floor(sqrt(k))**2 (a square grid of tokens).

    Fixes vs. original: the docstring documented a nonexistent
    `input_embeddings` parameter, and the per-index Python loop is replaced
    by a single vectorized indexed assignment.

    Parameters:
    - token_keep_mask: BoolTensor (B, L); True marks tokens to keep.
      The tensor is modified in place.

    Returns:
    - The same (mutated) token_keep_mask tensor.
    """
    rows, cols = token_keep_mask.nonzero(as_tuple=True)
    kept = rows.size(0)

    # Largest perfect square not exceeding the current keep count.
    target = math.floor(math.sqrt(kept)) ** 2
    surplus = kept - target
    if surplus > 0:
        # Random subset of kept positions to flip off (indices are unique,
        # so a single vectorized assignment is equivalent to the loop).
        drop = torch.randperm(kept)[:surplus]
        token_keep_mask[rows[drop], cols[drop]] = False

    return token_keep_mask
| |
|
| |
|
class FactorizedEntropyModel(EntropyBottleneck):
    """Fully-factorized entropy model built on compressai's EntropyBottleneck.

    Unlike the upstream forward (which adds uniform noise during training),
    this forward always quantizes deterministically in "dequantize" mode and
    returns the quantized tensor together with its per-element likelihoods.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through; all learnable state lives in EntropyBottleneck.
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize `x` around the learned medians and return (outputs, likelihood)."""
        # Default to the module's current train/eval flag.
        if training is None:
            training = self.training

        shape = x.size()  # NOTE(review): unused — kept as-is.

        # Per-channel medians of the learned distribution (quantization offsets).
        means = self._get_medians()

        # NOTE(review): `means.long()` truncates the float medians before they
        # are used as offsets; upstream compressai passes the float medians
        # directly — confirm the integer cast is intentional.
        outputs = self.quantize(
            x, "dequantize", means.long()
        )

        if not torch.jit.is_scripting():
            # Likelihood of each quantized value under the factorized prior.
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                # Clamp likelihoods away from zero for numerical stability.
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood

    def compress(self, x):
        """Range-encode `x` into one byte string per batch element."""
        # Maps every element to the CDF of its channel.
        indexes = self._build_indexes(x.size())
        # Medians act as symbol offsets during range coding.
        medians = self._get_medians().detach()
        medians = medians.expand_as(x)
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Inverse of :meth:`compress`; `size` is the per-sample spatial size."""
        # Output is (batch, 1 channel, *size).
        output_size = (len(strings), 1, *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Broadcast the medians to the decoded shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        return super().decompress(strings, indexes, medians.dtype, medians)

    def _preprocess(self, x):
        # NCHW -> NHWC layout change.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention that returns both the output and the
    attention map.

    The Q·K score is computed in fp32 with autocast disabled, and softmax is
    applied after subtracting the per-row maximum for numerical stability.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default scaling is 1/sqrt(head_dim) unless explicitly overridden.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        heads = self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, C/heads)
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, heads, channels // heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        # Score in full precision regardless of any surrounding autocast.
        with torch.cuda.amp.autocast(enabled=False):
            attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale

        # Max-subtracted softmax, then attention dropout.
        attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
        attn = self.attn_drop(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(out))
        return out, attn
| |
|
| |
|
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual
    connection and (optional) stochastic depth."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth; identity when the rate is zero.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        attn_out, attn_map = self.attn(self.norm1(x))
        if return_attention:
            # Caller only wants the attention map (e.g. for visualization).
            return attn_map
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
| |
|
| |
|
class LabelSmoothingCrossEntropy(nn.Module):
    """Per-sample cross-entropy with label smoothing (no reduction).

    loss = (1 - smoothing) * NLL(target) + smoothing * mean(-log p),
    returned per sample so the caller can apply its own masking/weighting.
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class.
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform smoothing term: average negative log-prob over all classes.
        smooth = -log_probs.mean(dim=-1)
        return self.confidence * nll + self.smoothing * smooth
| |
|
| |
|
class BertEmbeddings(nn.Module):
    """Token + learned absolute-position embeddings with LayerNorm and dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

        # Pre-built [0, 1, ..., max_pos-1] index row, sliced per sequence length.
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))

        # BERT-style truncated-normal-like init (std 0.02).
        torch.nn.init.normal_(self.word_embeddings.weight, std=.02)
        torch.nn.init.normal_(self.position_embeddings.weight, std=.02)

    def forward(
        self, input_ids
    ):
        seq_length = input_ids.size()[1]
        # Position ids for this sequence length.
        position_ids = self.position_ids[:, :seq_length]

        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
| |
|
| |
|
class MlmLayer(nn.Module):
    """MLM head: project decoder features and score them against the
    word-embedding table (weight-tied output layer with its own bias)."""

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        # Per-vocabulary-entry output bias.
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        hidden = self.ln(self.gelu(self.fc(x)))
        # Similarity to each embedding row, plus the learned bias.
        return torch.matmul(hidden, word_embeddings.transpose(0, 1)) + self.bias
| |
|
| |
|
class MaskedGenerativeEncoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        super().__init__()

        # Frozen VQ-GAN tokenizer; its weights are never trained here.
        config = OmegaConf.load('config/vqgan.yaml').model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed,
                             embed_dim=config.params.embed_dim,
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False

        # Token vocabulary: codebook entries + 1000 class slots + 1 [MASK] label.
        self.codebook_size = config.params.n_embed
        vocab_size = self.codebook_size + 1000 + 1
        # NOTE(review): resolves to codebook_size + 76; presumably the "fake"
        # class id used for unconditional generation (MAGE convention) — confirm.
        self.fake_class_label = self.codebook_size + 1100 - 1024
        self.mask_token_label = vocab_size - 1
        # NOTE(review): img_size + 1 positions only equals the token count
        # (16*16 tokens + cls) because img_size == 256 — confirm for other sizes.
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=embed_dim,
                                        max_position_embeddings=img_size + 1,
                                        dropout=0.1)

        # Mask-rate range sampled by random_sample_mask_rate().
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max

        # --------- encoder ---------
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) sin-cos position table, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # --------- decoder ---------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # If True, padded positions reuse the cls feature instead of mask_token.
        self.pad_with_cls_token = True

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # NOTE(review): pixel-prediction head appears unused in this file
        # (token pipeline predicts via mlm_layer) — confirm before removing.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)

        # MLM head scoring decoder features against the token-embedding table.
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss

        # Label-smoothed per-token CE; reduced with the mask in forward_loss().
        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

        # Factorized entropy model over the single-channel latent of visible tokens.
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()
|
    def initialize_weights(self):
        """Fill the fixed sin-cos position tables and initialize all layers."""
        # Fixed 2D sin-cos position embeddings (with a cls slot) for the encoder.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Same for the (non-trainable) decoder position table.
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Patch-embedding conv initialized like a Linear (xavier on the flattened kernel).
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Small-normal init for the learnable tokens / learned decoder positions.
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

        # Recursively initialize every Linear / LayerNorm submodule.
        self.apply(self._init_weights)
| |
|
| | def _init_weights(self, m): |
| | if isinstance(m, nn.Linear): |
| | |
| | torch.nn.init.xavier_uniform_(m.weight) |
| | if isinstance(m, nn.Linear) and m.bias is not None: |
| | nn.init.constant_(m.bias, 0) |
| | elif isinstance(m, nn.LayerNorm): |
| | nn.init.constant_(m.bias, 0) |
| | nn.init.constant_(m.weight, 1.0) |
| |
|
| | def random_sample_mask_rate(self): |
| | |
| | random_sample = 1 - torch.rand(1) |
| | |
| | mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min) |
| | return mask_rate.item() |
| |
|
| | def get_cdf_token_mask(self, token_all_mask): |
| | bsz, seq_len = token_all_mask.size() |
| | |
| | dist_normal = torch.distributions.Normal(0, 2) |
| | cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1)) |
| | cdf_mask_token = (cdf_mask_token - .5) * 2 |
| | cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp', |
| | b=bsz, s=seq_len) |
| | |
| | cdf_mask_token = F.pad(cdf_mask_token, (1, 0)) |
| | return cdf_mask_token |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """
        input: x: (B, 3, H, W)

        Tokenizes the images with the frozen VQ-GAN and draws the random
        drop/mask patterns. Returns (gt_indices, token_indices,
        unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio).
        """
        # VQ-GAN tokenization; gradients never flow into the tokenizer.
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)

        _, _, token_indices = token_tuple
        token_indices = token_indices.reshape(z_q.size(0), -1)
        gt_indices = token_indices.clone().detach().long()

        bsz, seq_len = token_indices.size()
        mask_ratio_min = self.mask_ratio_min

        if is_training:
            # Random overall mask rate; additionally a mask_ratio_min fraction
            # of tokens is dropped (removed from the encoder input entirely).
            mask_rate = self.random_sample_mask_rate()
            num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is not None:
                mask_rate = manual_mask_rate
            else:
                raise ValueError("mask_rate should be provided for inference!")

        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len

        # Rejection sampling: ties in the noise could give a row the wrong
        # number of dropped/masked tokens, in which case the noise is redrawn.
        while True:
            noise = torch.rand(bsz, seq_len, device=x.device)
            sorted_noise, _ = torch.sort(noise, dim=1)
            if num_dropped_tokens > 0:
                cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
            else:
                cutoff_drop = torch.zeros((bsz, 1), device=x.device)
            cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
            token_drop_mask = (noise <= cutoff_drop).float()
            token_all_mask = (noise <= cutoff_mask).float()
            if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")

        # Token ids that stay visible, gathered per row in row-major order
        # (every row has the same count, so the reshape is well-defined).
        unmasked_pos = token_all_mask == 0
        unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)

        return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio
| |
|
| | def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask): |
| | bsz, seq_len = gt_indices.size() |
| | padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label) |
| | |
| | |
| | unmasked_token_counter = [0 for _ in range(bsz)] |
| | |
| | for b in range(bsz): |
| | for idx in range(seq_len): |
| | |
| | if (token_all_mask[b, idx] == 0): |
| | |
| | padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]] |
| | |
| | unmasked_token_counter[b] += 1 |
| | |
| | token_indices = padded_token_indices |
| | |
| | |
| | token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1) |
| | token_indices[:, 0] = self.fake_class_label |
| | |
| | |
| | |
| | token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1) |
| | token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1) |
| | token_indices = token_indices.long() |
| |
|
| | |
| | |
| | input_embeddings = self.token_emb(token_indices) |
| | |
| | bsz, seq_len, emb_dim = input_embeddings.shape |
| |
|
| | |
| | token_keep_mask = 1 - token_drop_mask |
| | input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim) |
| | |
| | |
| | |
| | x = input_embeddings_after_drop |
| | for blk in self.blocks: |
| | x = blk(x) |
| | x = self.norm(x) |
| | |
| |
|
| | return x, token_indices, token_all_mask, token_drop_mask |
| |
|
| | def forward_decoding(self, x, token_drop_mask, token_all_mask): |
| | """ |
| | x: output x of forward_encoder() |
| | token_drop_mask: positions for dropped tokens |
| | token_all_mask: positions for masked tokens |
| | """ |
| | |
| | |
| | x = self.decoder_embed(x) |
| |
|
| | |
| | |
| | if self.pad_with_cls_token: |
| | mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1) |
| | else: |
| | mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1) |
| |
|
| | |
| | |
| | x_after_pad = mask_tokens.clone() |
| | x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) |
| | |
| | x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad) |
| |
|
| | |
| | x = x_after_pad + self.decoder_pos_embed_learned |
| |
|
| | |
| | for blk in self.decoder_blocks: |
| | x = blk(x) |
| |
|
| | x = self.decoder_norm(x) |
| |
|
| | word_embeddings = self.token_emb.word_embeddings.weight.data.detach() |
| | logits = self.mlm_layer(x, word_embeddings) |
| | |
| |
|
| | return logits |
| |
|
| | def forward_loss(self, gt_indices, logits, mask): |
| | bsz, seq_len = gt_indices.size() |
| | |
| | loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz*seq_len, -1), gt_indices.reshape(bsz*seq_len)) |
| | loss = loss.reshape(bsz, seq_len) |
| | loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum() |
| | return loss |
| |
|
| | def cal_lmbda(self, mask_ratio, A=5e-1, B=8): |
| | lmbda = A * torch.exp(B * (1 - mask_ratio)) |
| | return lmbda |
| | |
| | def cal_loss(self, logits, gt_indices, mask, mask_ratio): |
| | mask_ratio = torch.tensor(mask_ratio) |
| | |
| | task_loss = self.forward_loss(gt_indices, logits, mask) |
| | lmbda = self.cal_lmbda(mask_ratio) |
| | |
| | return task_loss, lmbda |
| |
|
    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """Full compression round-trip.

        Tokenizes `imgs`, sends the visible-token latent through the entropy
        bottleneck, losslessly entropy-codes the binary mask with torchac,
        then decodes logits for the masked tokens and computes the losses.
        Returns a dict of logits, likelihoods, losses and bookkeeping tensors.
        """
        gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio = self.pre_encoding(imgs, is_training, manual_mask_rate)
        latent = latent.unsqueeze(1)  # add a channel dim for the bottleneck

        # NOTE(review): `latent` holds discrete token *indices*; the factorized
        # bottleneck treats them as continuous values — confirm intended.
        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)

        # Losslessly entropy-code the binary mask with the fixed Normal(0,2) CDF.
        cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
        sym = (token_all_mask.short() + 1).cpu()  # shift {0,1} -> {1,2} to index the zero-padded CDF
        bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
        # assumes a 16x16 token grid (img_size 256 / patch 16) — TODO confirm
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=16, w=16).unsqueeze(1)

        # Decoder-side reconstruction of the mask (round-trip through torchac).
        decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
        decoded_mask = (decoded_sym - 1).to(device=imgs.device)
        latent_hat = latent_hat.squeeze(1)
        x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)

        task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            'bs_mask_token': bs_mask_token,
            'mask_ratio': mask_ratio,
            'lambda': lmbda,
            'mask_vis': 1 - mask_vis,  # inverted so 1 marks a kept token
        }
        return return_dict
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
    def gen_img(self, logits, token_all_mask, token_indices, num_iter=12, choice_temperature=4.5):
        """
        generated image at inference
        seed: random seed
        logits: predicted logits by model decoder
        token_all_mask: mask token indices
        token_indices: token indices of the input image after the vq tokenizer
        num_iter: number of iterations for sampling
        choice_temperature: temperature for sampling

        MaskGIT-style iterative decoding: each round samples candidate tokens,
        keeps the most confident ones, and re-masks the rest on a cosine
        schedule. NOTE(review): hard-coded .cuda() calls make this CUDA-only.
        """
        bsz = logits.size(0)
        codebook_emb_dim = 256  # NOTE(review): hard-coded VQ-GAN embed dim — confirm it matches config
        codebook_size = 1024    # NOTE(review): hard-coded; presumably equals self.codebook_size — confirm
        mask_token_id = self.mask_token_label
        # Known tokens get infinite confidence so they are never re-masked.
        _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
        unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
        for step in range(num_iter):
            if step == 0:
                # First iteration reuses the logits supplied by the caller.
                cur_ids = token_indices.clone().long()
                cur_ids = cur_ids[:, 1:]  # strip the class-token position
                logits = logits[:, 1:, :codebook_size]

                # Sample a candidate token at every position.
                sample_dist = torch.distributions.categorical.Categorical(logits=logits)
                sampled_ids = sample_dist.sample()

                # Keep already-known tokens; only fill [MASK] positions.
                unknown_map = (cur_ids == mask_token_id)
                sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

                # Cosine schedule for the fraction of tokens to re-mask.
                ratio = 1. * (step + 1) / num_iter
                mask_ratio = np.cos(math.pi / 2. * ratio)

                # Confidence of each sampled token = its softmax probability.
                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
                unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
                mask_ratio = torch.tensor(mask_ratio).cuda()
                # Number of tokens to re-mask this round.
                mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
                # Clamp: re-mask at least 1 and reveal at least 1 token.
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

                # Re-mask the lowest-confidence tokens (Gumbel-perturbed).
                masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
                token_indices = torch.where(masking, mask_token_id, sampled_ids)
            else:
                # Later iterations: re-encode the current tokens and predict
                # fresh logits with the encoder + decoder stacks.
                cur_ids = token_indices.clone().long()
                token_indices = torch.cat(
                    [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
                token_indices[:, 0] = self.fake_class_label
                token_indices = token_indices.long()
                token_all_mask = token_indices == mask_token_id

                token_drop_mask = torch.zeros_like(token_indices)  # nothing is dropped at inference

                input_embeddings = self.token_emb(token_indices)

                # ViT encoder over the full (padded) sequence.
                x = input_embeddings
                for blk in self.blocks:
                    x = blk(x)
                x = self.norm(x)

                # Decoder logits over the codebook (class token stripped).
                logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
                logits = logits[:, 1:, :codebook_size]

                sample_dist = torch.distributions.categorical.Categorical(logits=logits)
                sampled_ids = sample_dist.sample()

                # Keep known tokens; only fill [MASK] positions.
                unknown_map = (cur_ids == mask_token_id)
                sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

                ratio = 1. * (step + 1) / num_iter

                mask_ratio = np.cos(math.pi / 2. * ratio)

                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
                unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
                mask_ratio = torch.tensor(mask_ratio).cuda()
                mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()

                # Clamp: re-mask at least 1 and reveal at least 1 token.
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

                masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
                token_indices = torch.where(masking, mask_token_id, sampled_ids)

        # Decode the final token ids back to pixels with the VQ-GAN decoder.
        # assumes a 16x16 token grid — TODO confirm for other img_size values
        z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, 16, 16, codebook_emb_dim))
        gen_images = self.vqgan.decode(z_q)
        return gen_images
| |
|
| |
|
def mage_vit_base_patch16(**kwargs):
    """MAGE ViT-Base/16: 768-dim, 12-layer encoder with an 8-layer decoder.

    Extra keyword arguments are forwarded to MaskedGenerativeEncoderViT.
    """
    return MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
| |
|
| |
|
def mage_vit_large_patch16(**kwargs):
    """MAGE ViT-Large/16: 1024-dim, 24-layer encoder with an 8-layer decoder.

    Extra keyword arguments are forwarded to MaskedGenerativeEncoderViT.
    """
    return MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
| |
|