# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models import SwinTransformer
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .base import BaseSelfSupervisor


@MODELS.register_module()
class SimMIMSwinTransformer(SwinTransformer):
"""Swin Transformer for SimMIM pre-training.
Args:
Args:
arch (str | dict): Swin Transformer architecture
Defaults to 'T'.
img_size (int | tuple): The size of input image.
Defaults to 224.
in_channels (int): The num of input channels.
Defaults to 3.
drop_rate (float): Dropout rate after embedding.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate.
Defaults to 0.1.
        out_indices (tuple): Indices of the stages to output. Defaults to
            (3, ).
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
        norm_cfg (dict): Config dict for the normalization layer at the end
            of the backbone. Defaults to dict(type='LN').
stage_cfgs (Sequence | dict): Extra config dict for each
stage. Defaults to empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to empty dict.
        pad_small_map (bool): If True, pad small feature maps to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting windows and shrink the window size to the
            size of the feature map, which is commonly used in
            classification. Defaults to False.
        init_cfg (dict, optional): Config dict for initialization.
            Defaults to None.
    """

    def __init__(self,
arch: Union[str, dict] = 'T',
img_size: Union[Tuple[int, int], int] = 224,
in_channels: int = 3,
drop_rate: float = 0.,
drop_path_rate: float = 0.1,
out_indices: tuple = (3, ),
use_abs_pos_embed: bool = False,
with_cp: bool = False,
                 frozen_stages: int = -1,
norm_eval: bool = False,
norm_cfg: dict = dict(type='LN'),
stage_cfgs: Union[Sequence, dict] = dict(),
patch_cfg: dict = dict(),
pad_small_map: bool = False,
init_cfg: Optional[dict] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
in_channels=in_channels,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
out_indices=out_indices,
use_abs_pos_embed=use_abs_pos_embed,
with_cp=with_cp,
frozen_stages=frozen_stages,
norm_eval=norm_eval,
norm_cfg=norm_cfg,
stage_cfgs=stage_cfgs,
patch_cfg=patch_cfg,
pad_small_map=pad_small_map,
init_cfg=init_cfg)
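
        # Learnable token used in place of the embeddings of masked patches.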
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def init_weights(self) -> None:
"""Initialize weights."""
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
            # Skip the default initialization when a pretrained model is
            # used.
return
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
        trunc_normal_(self.mask_token, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
"""Initialize weights."""
if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor]) -> Sequence[torch.Tensor]:
"""Generate features for masked images.
The function supports two kind of forward behaviors. If the ``mask`` is
not ``None``, the forward function will be executed as masked image
modeling pre-training; if the ``mask`` is ``None``, the forward
function will call ``super().forward()``, which extract features from
images without mask.
Args:
x (torch.Tensor): Input images.
mask (torch.Tensor, optional): Masks for images.
Returns:
tuple: A tuple containing features from multi-stages.
"""
if mask is None:
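            # No mask given: behave as a standard Swin backbone and simply
            # extract features.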
return super().forward(x)
else:
x, hw_shape = self.patch_embed(x)
B, L, _ = x.shape
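            # Replace the embeddings of masked patches with the learnable
            # mask token; ``w`` is 1 at masked positions and 0 elsewhere.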
mask_token = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
x = x * (1. - w) + mask_token * w
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
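
            # Run the Swin stages; reshape each selected stage output from
            # (B, L, C) tokens into an NCHW feature map.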
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
                    out = out.view(-1, *hw_shape, stage.out_channels)
                    out = out.permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)


@MODELS.register_module()
class SimMIM(BaseSelfSupervisor):
    """SimMIM.

Implementation of `SimMIM: A Simple Framework for Masked Image Modeling
<https://arxiv.org/abs/2111.09886>`_.
"""

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input images without mask."""
        return self.backbone(inputs, mask=None)

def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
mask = torch.stack([data_sample.mask for data_sample in data_samples])
img_latent = self.backbone(inputs, mask)
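        # Reconstruct pixels from the masked latent features; the head
        # computes the reconstruction loss on the masked patches.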
img_rec = self.neck(img_latent[0])
loss = self.head.loss(img_rec, inputs, mask)
losses = dict(loss=loss)
return losses
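

# ---------------------------------------------------------------------------
# Minimal usage sketch (an addition, not part of the upstream module). It
# assumes mmpretrain and torch are installed and uses the default 4x4 patch
# embedding, so a 192x192 input gives a 48x48 patch grid; the printed shape
# is illustrative for Swin-B with out_indices=(3, ).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    backbone = SimMIMSwinTransformer(arch='B', img_size=192)
    images = torch.randn(2, 3, 192, 192)
    # Random patch-level mask: 1 marks a masked patch, 0 a visible one.
    mask = torch.randint(0, 2, (2, 48, 48))
    feats = backbone(images, mask)
    print([tuple(f.shape) for f in feats])  # e.g. [(2, 1024, 6, 6)]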