import logging
from dataclasses import dataclass
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.vision_transformer import PatchEmbed, Block

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer

try:
    from apex.normalization import FusedLayerNorm
except ImportError:
    # fall back to the stock implementation when apex is not installed
    FusedLayerNorm = nn.LayerNorm


logger = logging.getLogger(__name__)

@dataclass
class MaeConfig(FairseqDataclass):
    input_size: int = 224
    in_chans: int = 3
    patch_size: int = 16
    embed_dim: int = 768
    depth: int = 12
    num_heads: int = 12
    decoder_embed_dim: int = 512
    decoder_depth: int = 8
    decoder_num_heads: int = 16
    mlp_ratio: int = 4
    norm_eps: float = 1e-6

    drop_path_rate: float = 0.0

    mask_ratio: float = 0.75
    norm_pix_loss: bool = True

    w2v_block: bool = False
    alt_block: bool = False
    alt_block2: bool = False
    alt_attention: bool = False
    block_dropout: float = 0
    attention_dropout: float = 0
    activation_dropout: float = 0
    layer_norm_first: bool = False

    fused_ln: bool = True
    end_of_block_targets: bool = True

    no_decoder_embed: bool = False
    no_decoder_pos_embed: bool = False
    mask_noise_std: float = 0

    single_qkv: bool = False
    use_rel_pos_bias: bool = False
    # referenced by make_block() below but missing from the original config;
    # assumed to default to False (per-block rather than shared bias tables)
    shared_rel_pos_bias: bool = False
    no_cls: bool = False

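# Note on the defaults: they match the ViT-Base MAE setup, i.e. a 12-layer,
# 768-dim encoder over 16x16 patches of a 224x224 image ((224 / 16) ** 2 = 196
# patch tokens, plus one cls token unless no_cls is set), paired with a
# lighter 8-layer, 512-dim decoder that only runs during pre-training.
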
def modify_relative_position_bias(orig_bias, bsz, mask):
    """Expand a (heads, T, T) relative position bias to a batch and drop the
    rows/columns that belong to masked-out tokens."""
    if mask is None:
        return orig_bias.unsqueeze(0).repeat(
            bsz, 1, 1, 1
        )  # heads x T x T -> bsz x heads x T x T
    heads, max_seq_len, _ = orig_bias.shape
    # prepend a zero for the cls token (never masked), then flip to "kept"
    mask_for_rel_pos_bias = torch.cat(
        (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
    ).bool()  # bsz x T
    unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
    unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
        1, heads, 1
    )  # bsz x heads x T
    b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
        bsz, 1, 1, 1
    )  # bsz x heads x T x T
    # select the kept rows ...
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
        unmasked_for_rel_pos_bias.unsqueeze(-1)
    )
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
    new_len = b_t_t_rel_pos_bias.size(-2)
    # ... then the kept columns
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
        unmasked_for_rel_pos_bias.unsqueeze(-2)
    )
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
    return b_t_t_rel_pos_bias

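# Shape walk-through (illustrative numbers, not from the model): with heads=2,
# a 5-token bias (cls + 4 patches) and mask = [[0, 1, 1, 0]] for bsz=1, the
# kept positions are [cls, p1, p4], so the (2, 5, 5) bias collapses to a
# (1, 2, 3, 3) tensor that can be added directly to the pruned attention logits.
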
class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        use_rel_pos_bias=False,
        window_size=None,
        alt_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import Attention, DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.use_rel_pos_bias = use_rel_pos_bias
        if use_rel_pos_bias:
            self.attn = AltAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                window_size=window_size,
            )
        elif alt_attention:
            from .multi.modules import AltAttention as AltAttention2

            self.attn = AltAttention2(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
        else:
            self.attn = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, rel_pos_bias=None, pos_mask=None):
        # Returns (features, targets): `t` is the FFN output when ffn_targets
        # is set, otherwise the end-of-block output.
        if self.layer_norm_first:
            # pre-norm residual path
            if self.use_rel_pos_bias:
                x = x + self.drop_path(
                    self.attn(
                        self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
                    )
                )
            else:
                x = x + self.drop_path(self.attn(self.norm1(x)))
            t = self.mlp(self.norm2(x))
            x = x + self.drop_path(t)
            if not self.ffn_targets:
                t = x
            return x, t
        else:
            # post-norm residual path
            if self.use_rel_pos_bias:
                x = x + self.drop_path(
                    self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
                )
            else:
                x = x + self.drop_path(self.attn(x))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(x))
            if not self.ffn_targets:
                t = x
            return x, t

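# A minimal sketch of how a block is driven (hypothetical tensors, for
# illustration only):
#
#   blk = AltBlock(dim=768, num_heads=12, layer_norm_first=True)
#   x = torch.randn(2, 197, 768)  # batch x (cls + 196 patches) x dim
#   x, t = blk(x)                 # t feeds layer-target style losses
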
class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # q and v get learnable biases while k stays bias-free; the zero k bias
        # is concatenated on the fly in forward()
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # one bias per relative offset, plus 3 extra entries for
            # cls-to-token, token-to-cls and cls-to-cls interactions
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # num_relative_distance x num_heads

            # build a lookup table mapping each (query, key) pair to its
            # relative-offset index
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2 x Wh x Ww
            coords_flatten = torch.flatten(coords, 1)  # 2 x Wh*Ww
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :]
            )  # 2 x Wh*Ww x Wh*Ww
            relative_coords = relative_coords.permute(
                1, 2, 0
            ).contiguous()  # Wh*Ww x Wh*Ww x 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2,
                dtype=relative_coords.dtype,
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, pos_mask=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # decoupled biases: learnable for q and v, fixed zeros for k
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # each: B x num_heads x N x head_dim

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1,
                -1,
            )  # (Wh*Ww + 1) x (Wh*Ww + 1) x num_heads
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # num_heads x (Wh*Ww + 1) x (Wh*Ww + 1)
            attn = attn + modify_relative_position_bias(
                relative_position_bias, x.size(0), pos_mask
            )

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1
        ) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )

        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2,
            dtype=relative_coords.dtype,
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] + 1,
            self.window_size[0] * self.window_size[1] + 1,
            -1,
        )
        return relative_position_bias.permute(2, 0, 1).contiguous()

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

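# Example shapes (assumed values, for illustration): with embed_dim=768 and
# grid_size=14 (a 224 image with 16x16 patches), the result is a (196, 768)
# array, or (197, 768) with cls_token=True since an all-zero row is prepended
# for the cls position.
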
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode each axis
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # np.float was removed from recent NumPy releases; use the explicit dtype
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

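# The frequencies follow the standard transformer recipe: omega_i =
# 1 / 10000 ** (2i / D). For example, with D=8 a position p is encoded as
# [sin(p), sin(p/10), sin(p/100), sin(p/1000), cos(p), cos(p/10), ...]
# (the sin block first, then the matching cos block).
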
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) of the checkpoint's patch grid
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) of the new patch grid
        new_size = int(num_patches ** 0.5)

        if orig_size != new_size:
            logger.info(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            # extra tokens (e.g. cls) are kept unchanged
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed

@register_model("mae", dataclass=MaeConfig)
class MaeModel(BaseFairseqModel):
    def __init__(self, cfg: MaeConfig):
        super().__init__()
        self.cfg = cfg

        self.mask_ratio = cfg.mask_ratio

        # MAE encoder specifics
        self.patch_embed = PatchEmbed(
            cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
        )
        # fixed sin-cos embedding, hence requires_grad=False
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim),
            requires_grad=False,
        )

        norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)

        # stochastic depth decay rule: drop path rate grows linearly with depth
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]

        def make_block(drop_path):
            if cfg.w2v_block:
                return TransformerSentenceEncoderLayer(
                    embedding_dim=cfg.embed_dim,
                    ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
                    num_attention_heads=cfg.num_heads,
                    dropout=cfg.block_dropout,
                    attention_dropout=cfg.attention_dropout,
                    activation_dropout=cfg.activation_dropout,
                    activation_fn="gelu",
                    layer_norm_first=cfg.layer_norm_first,
                    drop_path=drop_path,
                    norm_eps=1e-6,
                    single_qkv=cfg.single_qkv,
                    fused_ln=cfg.fused_ln,
                )
            elif cfg.alt_block:
                window_size = (
                    cfg.input_size // self.patch_embed.patch_size[0],
                    cfg.input_size // self.patch_embed.patch_size[1],
                )
                return AltBlock(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    layer_norm_first=cfg.layer_norm_first,
                    ffn_targets=not cfg.end_of_block_targets,
                    use_rel_pos_bias=cfg.use_rel_pos_bias,
                    window_size=window_size
                    if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
                    else None,
                    alt_attention=cfg.alt_attention,
                )
            elif cfg.alt_block2:
                from .multi.modules import AltBlock as AltBlock2

                return AltBlock2(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    layer_norm_first=cfg.layer_norm_first,
                    ffn_targets=not cfg.end_of_block_targets,
                )
            else:
                return Block(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                )

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.norm = norm_layer(cfg.embed_dim)

        # MAE decoder specifics
        self.decoder_embed = (
            nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
            if not cfg.no_decoder_embed
            else None
        )

        self.mask_token = (
            nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    cfg.decoder_embed_dim
                    if not cfg.no_decoder_embed
                    else cfg.embed_dim,
                )
            )
            if cfg.mask_noise_std <= 0
            else None
        )

        # sized to match the encoder table: no extra cls slot when no_cls is
        # set, otherwise initialize_weights() could not copy the sin-cos table
        self.decoder_pos_embed = (
            nn.Parameter(
                torch.zeros(
                    1,
                    num_patches + int(not cfg.no_cls),
                    cfg.decoder_embed_dim
                    if not cfg.no_decoder_embed
                    else cfg.embed_dim,
                ),
                requires_grad=False,
            )
            if not cfg.no_decoder_pos_embed
            else None
        )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    cfg.decoder_embed_dim,
                    cfg.decoder_num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                )
                for _ in range(cfg.decoder_depth)
            ]
        )

        self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
        # predict the raw pixels of each patch
        self.decoder_pred = nn.Linear(
            cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
        )

        self.norm_pix_loss = cfg.norm_pix_loss

        self.initialize_weights()

        # tag parameters so that downstream optimizers can exempt biases and
        # 1-d (norm) weights from weight decay
        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias"):
                p.param_group = "no_decay"
            else:
                p.param_group = "with_decay"

    def initialize_weights(self):
        # initialize (and freeze) pos_embed with the fixed sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches ** 0.5),
            cls_token=not self.cfg.no_cls,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.decoder_pos_embed is not None:
            decoder_pos_embed = get_2d_sincos_pos_embed(
                self.decoder_pos_embed.shape[-1],
                int(self.patch_embed.num_patches ** 0.5),
                cls_token=not self.cfg.no_cls,
            )
            self.decoder_pos_embed.data.copy_(
                torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
            )

        # initialize patch_embed's projection like an nn.Linear rather than a conv
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        if self.cls_token is not None:
            torch.nn.init.normal_(self.cls_token, std=0.02)

        if self.mask_token is not None:
            torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm modules
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # xavier_uniform, following the official JAX ViT initialization
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs

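    # patchify/unpatchify are exact inverses for square inputs; a quick
    # hypothetical check (not part of the model):
    #
    #   imgs = torch.randn(2, 3, 224, 224)
    #   assert torch.allclose(model.unpatchify(model.patchify(imgs)), imgs)
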
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample; ascending: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask in the original token order
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

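    # Worked example (illustrative numbers): with L=4 and mask_ratio=0.75,
    # len_keep=1; if ids_shuffle for a sample is [2, 0, 3, 1] then only token 2
    # is kept, mask (in original order) is [1, 1, 0, 1], and ids_restore maps
    # the [kept..., masked...] layout back to positions 0..3 for the decoder.
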
    @classmethod
    def build_model(cls, cfg: MaeConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed (skipping the cls slot when a cls token is used; with
        # no_cls the table has no cls row, so slicing it would drop a patch)
        if self.cls_token is not None:
            x = x + self.pos_embed[:, 1:, :]
        else:
            x = x + self.pos_embed

        # masking: length -> length * (1 - mask_ratio)
        if mask_ratio > 0:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)
        else:
            mask = ids_restore = None

        # append cls token
        if self.cls_token is not None:
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks; AltBlock and the wav2vec layer return
        # tuples, so keep only the features
        for blk in self.blocks:
            x = blk(x)
            if isinstance(x, tuple):
                x = x[0]

        if self.norm is not None:
            x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to the sequence (one per dropped patch; the cls
        # token, when present, occupies one of the input slots)
        num_cls = 1 if self.cls_token is not None else 0
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + num_cls - x.shape[1], 1
        )
        if self.cls_token is not None:
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        else:
            x_ = torch.cat([x, mask_tokens], dim=1)

        # unshuffle back to the original patch order
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )

        if self.cls_token is not None:
            x = torch.cat([x[:, :1, :], x_], dim=1)  # re-append cls token
        else:
            # without a cls token the unshuffled sequence is the full input
            x = x_

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        if self.cls_token is not None:
            # remove cls token
            x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove
        Returns the loss summed over masked patches and the number of masked
        patches (used as the sample size).
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # normalize each target patch by its own mean and variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean squared error per patch

        loss = (loss * mask).sum()  # only masked patches contribute
        return loss, mask.sum()

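    # Note: following the usual fairseq convention, the criterion is assumed
    # to divide the summed loss by sample_size, so the effective objective is
    # the mean squared error per masked patch; with norm_pix_loss the target
    # pixels are first standardized within each patch.
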
    def forward(self, imgs, predictions_only=False):
        latent, mask, ids_restore = self.forward_encoder(
            imgs, self.mask_ratio if not predictions_only else 0
        )

        if predictions_only:
            return latent

        pred = self.forward_decoder(latent, ids_restore)
        loss, sample_size = self.forward_loss(imgs, pred, mask)

        result = {
            "losses": {"regression": loss},
            "sample_size": sample_size,
        }
        return result

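    # Minimal usage sketch (hypothetical, for illustration only):
    #
    #   cfg = MaeConfig()
    #   model = MaeModel.build_model(cfg)
    #   imgs = torch.randn(2, 3, cfg.input_size, cfg.input_size)
    #   out = model(imgs)  # {"losses": {"regression": ...}, "sample_size": ...}
    #   feats = model(imgs, predictions_only=True)  # (2, 197, 768) features
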
    def remove_pretraining_modules(self):
        self.decoder_embed = None
        self.decoder_blocks = None
        self.decoder_norm = None
        self.decoder_pos_embed = None
        self.decoder_pred = None
        self.mask_token = None
        if self.cfg.layer_norm_first:
            self.norm = None