# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.hivit import BlockWithRPE
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding


class PatchSplit(nn.Module):
    """The up-sample module used in the neck (transformer pyramid network).

    Args:
        dim (int): The input dimension (channel number).
        fpn_dim (int): The fpn dimension (channel number).
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
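
    Example (an illustrative shape sketch; the values below are assumptions
    for demonstration, not defaults of this module)::

        >>> import torch
        >>> split = PatchSplit(dim=512, fpn_dim=256, norm_cfg=dict(type='LN'))
        >>> x = torch.randn(2, 49, 2, 2, 512)  # (B, N, H, W, C)
        >>> split(x).shape  # spatial dims doubled, channels -> fpn_dim
        torch.Size([2, 49, 4, 4, 256])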
| """ | |
| def __init__(self, dim, fpn_dim, norm_cfg): | |
| super().__init__() | |
| _, self.norm = build_norm_layer(norm_cfg, dim) | |
| self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False) | |
| self.fpn_dim = fpn_dim | |
    def forward(self, x):
        B, N, H, W, C = x.shape
        x = self.norm(x)
        x = self.reduction(x)
        # Rearrange the expanded channels into a 2x2 spatial neighborhood
        # (pixel-shuffle style): (B, N, H, W, 4 * fpn_dim) ->
        # (B, N, 2 * H, 2 * W, fpn_dim).
        x = x.reshape(B, N, H, W, 2, 2,
                      self.fpn_dim).permute(0, 1, 2, 4, 3, 5,
                                            6).reshape(B, N, 2 * H, 2 * W,
                                                       self.fpn_dim)
        return x


@MODELS.register_module()
class iTPNPretrainDecoder(BaseModule):
    """The neck module of iTPN (transformer pyramid network).

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 512.
        fpn_dim (int): The fpn dimension (channel number). Defaults to 256.
        fpn_depth (int): The layer number of feature pyramid. Defaults to 2.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 6.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        reconstruction_type (str): The iTPN supports two kinds of supervision:
            ``'pixel'`` and ``'clip'``. Defaults to 'pixel'.
        num_outs (int): The output number of neck (transformer pyramid
            network). Defaults to 3.
        qkv_bias (bool): Whether to add a learnable bias to query, key and
            value. Defaults to True.
        qk_scale (float, optional): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): The drop rate of attention output weights.
            Defaults to 0.
        predict_feature_dim (int, optional): The output dimension for
            supervision. Defaults to None, which falls back to
            ``patch_size**2 * in_chans``.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
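
    Example (a minimal usage sketch under the default ``'pixel'`` setting;
    the feature shapes follow the pyramid layout consumed by ``forward`` and
    are illustrative assumptions, not part of the API contract)::

        >>> import torch
        >>> neck = iTPNPretrainDecoder()
        >>> neck.init_weights()
        >>> B, L = 2, 49  # 49 visible token groups out of 196 patches
        >>> feats = (torch.randn(B, L, 4, 4, 128),  # /4 level, embed_dim // 4
        ...          torch.randn(B, L, 2, 2, 256),  # /8 level, embed_dim // 2
        ...          torch.randn(B, L, 512))        # /16 level, embed_dim
        >>> ids_restore = torch.argsort(torch.rand(B, 196), dim=1)
        >>> neck(feats, ids_restore).shape  # patch_size**2 * in_chans = 768
        torch.Size([2, 196, 768])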
| """ | |
| def __init__(self, | |
| num_patches: int = 196, | |
| patch_size: int = 16, | |
| in_chans: int = 3, | |
| embed_dim: int = 512, | |
| fpn_dim: int = 256, | |
| fpn_depth: int = 2, | |
| decoder_embed_dim: int = 512, | |
| decoder_depth: int = 6, | |
| decoder_num_heads: int = 16, | |
| mlp_ratio: int = 4, | |
| norm_cfg: dict = dict(type='LN', eps=1e-6), | |
| reconstruction_type: str = 'pixel', | |
| num_outs: int = 3, | |
| qkv_bias: bool = True, | |
| qk_scale: Optional[bool] = None, | |
| drop_rate: float = 0.0, | |
| attn_drop_rate: float = 0.0, | |
| predict_feature_dim: Optional[float] = None, | |
| init_cfg: Optional[Union[List[dict], dict]] = None) -> None: | |
| super().__init__(init_cfg=init_cfg) | |
| self.num_patches = num_patches | |
| assert reconstruction_type in ['pixel', 'clip'], \ | |
| 'iTPN method only support `pixel` and `clip`, ' \ | |
| f'but got `{reconstruction_type}`.' | |
| self.reconstruction_type = reconstruction_type | |
| self.num_outs = num_outs | |

        self.build_transformer_pyramid(
            num_outs=num_outs,
            embed_dim=embed_dim,
            fpn_dim=fpn_dim,
            fpn_depth=fpn_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            rpe=False,
            norm_cfg=norm_cfg,
        )

        # merge the output
        self.decoder_embed = nn.ModuleList()
        self.decoder_embed.append(
            nn.Sequential(
                nn.LayerNorm(fpn_dim),
                nn.Linear(fpn_dim, decoder_embed_dim, bias=True),
            ))
        if self.num_outs >= 2:
            self.decoder_embed.append(
                nn.Sequential(
                    nn.LayerNorm(fpn_dim),
                    nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True),
                ))
        if self.num_outs >= 3:
            self.decoder_embed.append(
                nn.Sequential(
                    nn.LayerNorm(fpn_dim),
                    nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True),
                ))
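        # NOTE: the /8 and /4 pyramid levels carry 4x and 16x more spatial
        # positions per token group, so their per-position widths are divided
        # by 4 and 16; after the flatten in ``forward``, every level then
        # contributes exactly ``decoder_embed_dim`` channels per token.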

        if reconstruction_type == 'pixel':
            self.mask_token = nn.Parameter(
                torch.zeros(1, 1, decoder_embed_dim))

            # create a new position embedding, different from the one in the
            # encoder: it is fixed (2D sin-cos), not learnable
            self.decoder_pos_embed = nn.Parameter(
                torch.zeros(1, self.num_patches, decoder_embed_dim),
                requires_grad=False)

            self.decoder_blocks = nn.ModuleList([
                TransformerEncoderLayer(
                    decoder_embed_dim,
                    decoder_num_heads,
                    int(mlp_ratio * decoder_embed_dim),
                    qkv_bias=True,
                    norm_cfg=norm_cfg) for _ in range(decoder_depth)
            ])
            self.decoder_norm_name, decoder_norm = build_norm_layer(
                norm_cfg, decoder_embed_dim, postfix=1)
            self.add_module(self.decoder_norm_name, decoder_norm)

            # Used to map features to pixels
            if predict_feature_dim is None:
                predict_feature_dim = patch_size**2 * in_chans
            self.decoder_pred = nn.Linear(
                decoder_embed_dim, predict_feature_dim, bias=True)
        else:
            _, norm = build_norm_layer(norm_cfg, embed_dim)
            self.add_module('norm', norm)

    def build_transformer_pyramid(self,
                                  num_outs=3,
                                  embed_dim=512,
                                  fpn_dim=256,
                                  fpn_depth=2,
                                  mlp_ratio=4.0,
                                  qkv_bias=True,
                                  qk_scale=None,
                                  drop_rate=0.0,
                                  attn_drop_rate=0.0,
                                  rpe=False,
                                  norm_cfg=None):
        Hp = None
        mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
        if num_outs > 1:
            if embed_dim != fpn_dim:
                self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim)
            else:
                self.align_dim_16tofpn = None
            self.fpn_modules = nn.ModuleList()
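            # The third positional argument of ``BlockWithRPE`` below is the
            # attention head count; with 0 heads the block degenerates into an
            # FFN-only block (its ``attn`` is None, cf. the check in
            # ``rescale_init_weight``).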
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg))
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=False,
                    norm_cfg=norm_cfg,
                ))

            self.align_dim_16to8 = nn.Linear(
                mlvl_dims['8'], fpn_dim, bias=False)
            self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg)
            self.block_16to8 = nn.Sequential(*[
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg,
                ) for _ in range(fpn_depth)
            ])

        if num_outs > 2:
            self.align_dim_8to4 = nn.Linear(
                mlvl_dims['4'], fpn_dim, bias=False)
            self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg)
            self.block_8to4 = nn.Sequential(*[
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg,
                ) for _ in range(fpn_depth)
            ])
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg))

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MAE decoder."""
        super().init_weights()

        if self.reconstruction_type == 'pixel':
            decoder_pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.decoder_pos_embed.shape[-1],
                cls_token=False)
            self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

            torch.nn.init.normal_(self.mask_token, std=.02)
        else:
            self.rescale_init_weight()

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            # depth-dependent rescaling: shrink residual-branch weights of
            # deeper layers by 1 / sqrt(2 * layer_id)
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.fpn_modules):
            if isinstance(layer, BlockWithRPE):
                if layer.attn is not None:
                    rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @property
    def decoder_norm(self):
        """The normalization layer of decoder."""
        return getattr(self, self.decoder_norm_name)

    def forward(self,
                x: Tuple[torch.Tensor, ...],
                ids_restore: torch.Tensor = None) -> torch.Tensor:
        """The forward function.

        The process computes the visible patches' feature vectors and the
        mask tokens to output feature vectors, which will be used for
        reconstruction.

        Args:
            x (Tuple[torch.Tensor, ...]): The multi-scale hidden features
                from the backbone; the last element holds the visible patch
                tokens and is of shape B x (L * (1 - mask_ratio)) x C.
            ids_restore (torch.Tensor): The indices to restore the original
                patch order after random masking.

        Returns:
            torch.Tensor: The reconstructed feature vectors, which is of
            shape B x (num_patches) x C.
        """
        features = x[:2]
        x = x[-1]
        B, L, _ = x.shape
        x = x[..., None, None, :]
        Hp = Wp = int(math.sqrt(L))

        outs = [x] if self.align_dim_16tofpn is None else [
            self.align_dim_16tofpn(x)
        ]
        if self.num_outs >= 2:
            x = self.block_16to8(
                self.split_16to8(x) + self.align_dim_16to8(features[1]))
            outs.append(x)
        if self.num_outs >= 3:
            x = self.block_8to4(
                self.split_8to4(x) + self.align_dim_8to4(features[0]))
            outs.append(x)
        if self.num_outs > 3:
            outs = [
                out.reshape(B, Hp, Wp, *out.shape[-3:]).permute(
                    0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3],
                                              Wp * out.shape[-2]).contiguous()
                for out in outs
            ]
            if self.num_outs >= 4:
                outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
            if self.num_outs >= 5:
                outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))

        for i, out in enumerate(outs):
            out = self.fpn_modules[i](out)
            outs[i] = out

        if self.reconstruction_type == 'pixel':
            feats = []
            for feat, layer in zip(outs, self.decoder_embed):
                x = layer(feat).reshape(B, L, -1)
                # append mask tokens to the sequence, then unshuffle back to
                # the original patch order (MAE style)
                mask_tokens = self.mask_token.repeat(
                    x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
                x = torch.cat([x, mask_tokens], dim=1)
                x = torch.gather(
                    x,
                    dim=1,
                    index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
                feats.append(x)
            x = feats.pop(0)
            # add pos embed
            x = x + self.decoder_pos_embed

            # fuse the remaining pyramid levels by summation
            for feat in feats:
                x = x + feat

            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
            x = self.decoder_norm(x)
            x = self.decoder_pred(x)

            return x
        else:
            feats = []
            for feat, layer in zip(outs, self.decoder_embed):
                x = layer(feat).reshape(B, L, -1)
                feats.append(x)
            x = feats.pop(0)
            # fuse the remaining pyramid levels by summation
            for feat in feats:
                x = x + feat
            x = self.norm(x)

            return x