# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from mmengine.runner.checkpoint import _load_checkpoint

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_clip_model
from .base import BaseSelfSupervisor
from .mae import MAEViT


@MODELS.register_module()
class CLIPGenerator(nn.Module):
"""Get the features and attention from the last layer of CLIP.
This module is used to generate target features in masked image modeling.
Args:
tokenizer_path (str): The path of the checkpoint of CLIP.
"""
def __init__(self, tokenizer_path: str) -> None:
super().__init__()
self.tokenizer_path = tokenizer_path
self.tokenizer = build_clip_model(
            _load_checkpoint(self.tokenizer_path), False)

@torch.no_grad()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the features and attention from the last layer of CLIP.
Args:
x (torch.Tensor): The input image, which is of shape (N, 3, H, W).
Returns:
Tuple[torch.Tensor, torch.Tensor]: The features and attention from
the last layer of CLIP, which are of shape (N, L, C) and (N, L, L),
respectively.
"""
# use the visual branch of CLIP to get the features
assert self.tokenizer is not None, 'Please check whether the ' \
'`self.tokenizer` is initialized correctly.'
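        # ``encode_image`` of the CLIP model built above returns both the
        # token features and the last-layer attention map (see Returns above)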
clip_features = self.tokenizer.encode_image(x)
        return clip_features


@MODELS.register_module()
class MILANViT(MAEViT):
"""Vision Transformer for MILAN pre-training.
Implementation of the encoder for `MILAN: Masked Image Pretraining on
Language Assisted Representation <https://arxiv.org/abs/2208.06049>`_.
This module inherits from MAEViT and only overrides the forward function
and replace random masking with attention masking.
"""
def attention_masking(
self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate attention mask for MILAN.
This is what is different from MAEViT, which uses random masking.
Attention masking generates attention mask for MILAN, according to
importance. The higher the importance, the more likely the patch is
kept.
Args:
x (torch.Tensor): Input images, which is of shape B x L x C.
mask_ratio (float): The ratio of patches to be masked.
importance (torch.Tensor): Importance of each patch, which is of
shape B x L.
Returns:
Tuple[torch.Tensor, ...]:
- ``x_masked``: masked image
- ``ids_restore``: the ids to restore original image
- ``ids_keep``: ids of the kept patches
- ``ids_dump``: ids of the removed patches
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = importance.to(x.device) # large is keep, small is remove
        # sample a patch ordering for each image: indices with larger
        # importance are more likely to be drawn first
        ids_shuffle = torch.multinomial(noise, L, replacement=False)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
ids_dump = ids_shuffle[:, len_keep:]
x_masked = torch.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, ids_restore, ids_keep, ids_dump
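
    # Illustrative note (toy numbers, not from the paper): with
    # importance = [[0.7, 0.1, 0.1, 0.1]] and mask_ratio = 0.5, len_keep is
    # int(4 * (1 - 0.5)) = 2; ``torch.multinomial`` then draws all four
    # indices without replacement with probability proportional to
    # importance, so patch 0 is the most likely to land in the first
    # ``len_keep`` positions and therefore to be kept.
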
def forward(
self,
x: torch.Tensor,
importance: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
The function supports two kind of forward behaviors. If the
``importance`` is ``None``, the function generates mask and masks some
patches randomly and get the hidden features for visible patches. The
mask is generated by importance. The higher the importance, the more
likely the patch is kept. The importance is calculated by CLIP.
The higher the CLIP score, the more likely the patch is kept. The CLIP
score is calculated by cross attention between the class token and all
other tokens from the last layer.
If the ``importance`` is ``torch.Tensor``, the forward function will
call ``super().forward()``, which extract features from images without
mask.
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
importance (torch.Tensor, optional): Importance of each patch,
which is of shape B x L.
Returns:
Tuple[torch.Tensor, ...]: masked image, the ids to restore original
image, ids of the kept patches, ids of the removed patches.
- ``x`` (torch.Tensor): hidden features, which is of shape
B x (L * mask_ratio) x C.
- ``ids_restore`` (torch.Tensor): ids to restore original image.
- ``ids_keep`` (torch.Tensor): ids of the kept patches.
- ``ids_dump`` (torch.Tensor): ids of the removed patches.
"""
if importance is None:
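            # no importance available: bypass MAEViT's random-masking forward
            # and fall back to the plain VisionTransformer forward, i.e.
            # extract features from the full, unmasked image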
return super(MAEViT, self).forward(x)
else:
B = x.shape[0]
x = self.patch_embed(x)[0]
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
            # masking: length -> length * (1 - mask_ratio)
x, ids_restore, ids_keep, ids_dump = self.attention_masking(
x, self.mask_ratio, importance)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
            for layer in self.layers:
x = layer(x)
# Use final norm
x = self.norm1(x)
            return x, ids_restore, ids_keep, ids_dump


@MODELS.register_module()
class MILAN(BaseSelfSupervisor):
"""MILAN.
Implementation of `MILAN: Masked Image Pretraining on Language Assisted
Representation <https://arxiv.org/abs/2208.06049>`_.
"""
    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input images without masking."""
        return self.backbone(inputs, importance=None)

def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (torch.Tensor): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
        # ids_restore is the same as in the original repo; it is used to
        # recover the original order of the tokens in the decoder.
clip_feature, importance = self.target_generator(inputs)
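        # keep only the cross attention from the class token (query index 0)
        # to the patch tokens, dropping the class-to-class entry; this is the
        # per-patch importance used for attention masking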
importance = importance[:, 0, 1:]
latent, ids_restore, ids_keep, ids_dump = self.backbone(
inputs, importance)
pred = self.neck(latent, ids_restore, ids_keep, ids_dump)
loss = self.head.loss(pred, clip_feature)
losses = dict(loss=loss)
return losses
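

# A minimal usage sketch (hypothetical: the neck/head configs and the CLIP
# checkpoint path below are placeholders, not values defined in this file):
#
#   model = MILAN(
#       backbone=dict(type='MILANViT', arch='b', patch_size=16,
#                     mask_ratio=0.75),
#       neck=dict(type=...),   # a decoder consuming latent + ids_* tensors
#       head=dict(type=...),   # a head comparing ``pred`` with CLIP features
#       target_generator=dict(
#           type='CLIPGenerator',
#           tokenizer_path='path/to/clip_checkpoint.pth'))
#   losses = model.loss(inputs, data_samples)  # inputs: (N, 3, H, W)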