import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block
from functools import partial


class MARDecoder(nn.Module):
    """ MAR decoder: masked-autoencoder-style decoder with a Vision Transformer backbone.
    """
    def __init__(self, img_size=256, vae_stride=16,
                 patch_size=1,
                 # encoder_embed_dim=1024,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4.,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.buffer_size = buffer_size
        # self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):
        # x: [B, buffer_size + num_visible, C] tokens (buffer tokens followed by the
        #    unmasked sequence tokens, typically from an upstream encoder);
        # mask: [B, seq_len], 1 = masked position, 0 = visible.
        # x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens: scatter the visible tokens back into the full-length sequence
        # and fill masked positions with the learned mask token
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        x = x_after_pad + self.decoder_pos_embed_learned

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        # drop the buffer tokens and add the positional embedding used by the diffusion head
        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def gradient_checkpointing_enable(self):
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.grad_checkpointing = False
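
# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the batch size, masking ratio, and
# random inputs below are assumptions, not part of the original module).
# It shows the expected input contract: `x` carries the buffer tokens plus the
# *visible* (unmasked) sequence tokens, `mask` is 1 at masked positions, and
# every sample masks the same number of tokens so the scatter in forward()
# lines up.
if __name__ == "__main__":
    torch.manual_seed(0)
    decoder = MARDecoder(img_size=256, vae_stride=16, patch_size=1,
                         decoder_embed_dim=1024, decoder_depth=16,
                         decoder_num_heads=16, buffer_size=64)

    bsz = 2
    seq_len = decoder.seq_len          # 256 tokens for a 256px image with stride 16
    num_masked = 192                   # same count for every sample

    mask = torch.zeros(bsz, seq_len)
    for i in range(bsz):
        mask[i, torch.randperm(seq_len)[:num_masked]] = 1.0

    # buffer tokens + visible tokens, e.g. produced by an upstream MAE-style encoder
    x = torch.randn(bsz, decoder.buffer_size + (seq_len - num_masked), 1024)

    out = decoder(x, mask)
    print(out.shape)                   # torch.Size([2, 256, 1024])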