# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


@MODELS.register_module()
class iTPNHiViT(HiViT):
    """HiViT for iTPN pre-training.

    Args:
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        inner_patches (int): Inner patch. Defaults to 4.
        stem_mlp_ratio (float): Ratio of MLP hidden dim to embedding dim
            in the first two stages. Defaults to 3.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim in
            the last stage. Defaults to 4.
        qkv_bias (bool): Enable bias for qkv projections if True.
            Defaults to True.
        qk_scale (float, optional): Override the default scaling of the
            query-key dot products if set. Defaults to None.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        ape (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to True.
        rpe (bool): If True, add relative position embedding to
            the patch embedding. Defaults to False.
        layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
        mask_ratio (float): The ratio of total number of patches to be masked.
            Defaults to 0.75.
        reconstruction_type (str): The reconstruction target of
            self-supervised learning, either 'pixel' or 'clip'.
            Defaults to 'pixel'.
    """

    def __init__(
        self,
        arch='base',
        img_size: int = 224,
        patch_size: int = 16,
        inner_patches: int = 4,
        stem_mlp_ratio: float = 3.,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        ape: bool = True,
        rpe: bool = False,
        layer_scale_init_value: float = 0.0,
        mask_ratio: float = 0.75,
        reconstruction_type: str = 'pixel',
    ):
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            inner_patches=inner_patches,
            stem_mlp_ratio=stem_mlp_ratio,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            ape=ape,
            rpe=rpe,
            layer_scale_init_value=layer_scale_init_value)

        self.pos_embed.requires_grad = False
        self.mask_ratio = mask_ratio

        assert reconstruction_type in ['pixel', 'clip'], \
            'iTPN method only supports `pixel` and `clip`, ' \
            f'but got `{reconstruction_type}`.'
        self.reconstruction_type = reconstruction_type
        self.num_patches = self.patch_embed.num_patches

        if reconstruction_type == 'clip':
            self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and mask token."""
        super().apply(self._init_weights)

        if self.reconstruction_type == 'clip':
            trunc_normal_(self.mask_token, std=0.02)
            self.rescale_init_weight()
        else:
            pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.pos_embed.shape[-1],
                cls_token=False)
            self.pos_embed.data.copy_(pos_embed.float())

            w = self.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            if isinstance(layer, BlockWithRPE):
                if layer.attn is not None:
                    rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)
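    # NOTE: dividing by sqrt(2 * layer_id) mirrors the depth-wise rescaling
    # used in BEiT-style initialization, damping the residual branches of
    # deeper blocks at the start of training.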

    def masking_id(self, batch_size, mask_ratio):
        """Generate the ids of visible patches, the ids to restore the
        original order, and the corresponding binary mask."""
        N, L = batch_size, self.pos_embed.size(1)
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(
            N, L, device=self.pos_embed.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=self.pos_embed.device)
        mask[:, :ids_keep.size(1)] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return ids_keep, ids_restore, mask
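    # A worked example of the argsort masking trick above (values are
    # illustrative, not from the source): with N=1, L=4, mask_ratio=0.75,
    # len_keep = int(4 * 0.25) = 1. If noise = [[0.7, 0.1, 0.9, 0.4]]:
    #   ids_shuffle = [[1, 3, 0, 2]]  # patch ids in ascending-noise order
    #   ids_restore = [[2, 0, 3, 1]]  # inverse permutation of ids_shuffle
    #   ids_keep    = [[1]]           # the len_keep lowest-noise patches
    #   mask        = [[1, 0, 1, 1]]  # 0 = keep, 1 = remove, original order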

    def forward_pixel(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If the ``mask``
        is ``True``, the function randomly masks some patches and computes
        the hidden features of the visible patches only, i.e. it runs as
        masked image modeling pre-training; if the ``mask`` is ``None`` or
        ``False``, the function calls ``super().forward()``, which extracts
        features from images without any mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether to randomly mask patches in the
                forward. Defaults to True.

        Returns:
            Tuple[tuple, torch.Tensor, torch.Tensor]: Hidden features,
            mask and the ids to restore original image.

            - ``outs`` (tuple): multi-scale hidden features, the last of
              which is of shape B x (L * (1 - mask_ratio)) x C.
            - ``mask`` (torch.Tensor): mask used to mask image.
            - ``ids_restore`` (torch.Tensor): ids to restore original image.
        """
        if mask is None or mask is False:
            return super().forward(x)
        else:
            B, C, H, W = x.shape
            ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)

            x = self.patch_embed(x)

            # keep only the visible patches
            x = torch.gather(
                x,
                dim=1,
                index=ids_keep[:, :, None, None,
                               None].expand(-1, -1, *x.shape[2:]))

            outs = []
            for blk in self.blocks[:-self.num_main_blocks]:
                if isinstance(blk, PatchMerge):
                    outs.append(x)
                x = blk(x)

            # collapse the inner-patch dimensions before the main blocks
            x = x[..., 0, 0, :]
            if self.ape:
                pos_embed = self.interpolate_pos_encoding(x, H, W)
                pos_embed = torch.gather(
                    pos_embed.expand(B, -1, -1),
                    dim=1,
                    index=ids_keep[:, :, None].expand(-1, -1,
                                                      pos_embed.shape[2]),
                )
                x = x + pos_embed
            x = self.pos_drop(x)

            for blk in self.blocks[-self.num_main_blocks:]:
                x = blk(x)

            outs.append(x)

            return (tuple(outs), mask, ids_restore)
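    # The tuple above feeds an MAE-style decoding path: the earlier entries
    # are intermediate pyramid features captured before each PatchMerge, and
    # the neck uses ``ids_restore`` to put mask tokens back in their
    # original positions (see ``iTPN.loss`` below).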

    def forward_clip(
            self,
            x: torch.Tensor,
            mask: Optional[Union[bool, torch.Tensor]] = True) -> Tuple:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If the ``mask``
        is a tensor, the masked patch tokens are replaced by the learnable
        mask token and hidden features are computed for all patches, i.e. it
        runs as masked image modeling pre-training; if the ``mask`` is
        ``None`` or ``False``, the function calls ``super().forward()``,
        which extracts features from images without any mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool | torch.Tensor, optional): The patch-level mask of
                shape B x L, or ``None``/``False`` to disable masking.
                Defaults to True.

        Returns:
            Tuple: multi-scale hidden features, the last of which is of
            shape B x L x C.
        """
        if mask is None or mask is False:
            return super().forward(x)
        else:
            B, C, H, W = x.shape
            x = self.patch_embed(x)

            outs = []
            for blk in self.blocks[:-self.num_main_blocks]:
                if isinstance(blk, PatchMerge):
                    outs.append(x)
                x = blk(x)

            # collapse the inner-patch dimensions before the main blocks
            x = x[..., 0, 0, :]
            B, L, _ = x.shape

            # replace masked patch tokens with the learnable mask token
            mask_token = self.mask_token.expand(B, L, -1)
            w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
            x = x * (1. - w) + mask_token * w

            if self.ape:
                pos_embed = self.interpolate_pos_encoding(x, H, W)
                x = x + pos_embed
            x = self.pos_drop(x)

            rpe_index = True if self.rpe else None

            for blk in self.blocks[-self.num_main_blocks:]:
                x = blk(x, rpe_index)

            outs.append(x)

            return tuple(outs)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[Union[bool, torch.Tensor]] = True) -> Tuple:
        """Generate features for masked images.

        Dispatches to ``forward_pixel`` or ``forward_clip`` according to
        ``self.reconstruction_type``. In both cases, if the ``mask`` is
        ``None`` or ``False``, features are extracted without any mask via
        ``super().forward()``.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool | torch.Tensor, optional): Whether (or where) to mask
                patches in the forward. Defaults to True.

        Returns:
            Tuple: See ``forward_pixel`` and ``forward_clip`` for details.
        """
        if self.reconstruction_type == 'pixel':
            return self.forward_pixel(x, mask)
        return self.forward_clip(x, mask)


@MODELS.register_module()
class iTPN(BaseSelfSupervisor):
    """iTPN.

    Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid
    Networks <https://arxiv.org/abs/2211.12735>`_.
    """

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input images without masking."""
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        if self.backbone.reconstruction_type == 'pixel':
            latent, mask, ids_restore = self.backbone(inputs)
            pred = self.neck(latent, ids_restore)
            loss = self.head.loss(pred, inputs, mask)
        else:
            mask = torch.stack(
                [data_sample.mask for data_sample in data_samples])
            img_latent = self.backbone(inputs[0], mask)

            # inputs[1] is the target image
            with torch.no_grad():
                target = self.target_generator(inputs[1])[0]
                target = target.detach()

            # iTPN contains a neck module
            feats = self.neck(img_latent)
            loss = self.head.loss(feats, target[:, 1:, :], mask)

        losses = dict(loss=loss)
        return losses
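

# A minimal usage sketch (an illustration, not part of the upstream module):
# build the backbone via the registry and run a masked forward pass. It
# assumes mmpretrain is installed and the classes above are registered.
#
#   from mmpretrain.registry import MODELS
#   import torch
#
#   backbone = MODELS.build(
#       dict(type='iTPNHiViT', arch='base', reconstruction_type='pixel'))
#   backbone.init_weights()
#   outs, mask, ids_restore = backbone(torch.randn(2, 3, 224, 224), mask=True)
#   # ``outs`` is a tuple of multi-scale features for the visible patches;
#   # ``mask`` and ``ids_restore`` feed the MAE-style neck and loss.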