Spaces:
Runtime error
Runtime error
File size: 18,944 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 |
# Copyright (c) OpenMMLab. All rights reserved.
# Part of code is modified from BEiT
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.models.backbones import BEiTViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor
class Conv2d(nn.Module):
    """2D convolution rewritten after the DALL-E reference code.

    The kernel ``w`` and the bias ``b`` are stored as plain parameters and
    applied with :func:`torch.nn.functional.conv2d`, padded by
    ``(kw - 1) // 2`` on every side.

    Args:
        n_in (int): Number of input channels.
        n_out (int): Number of output channels.
        kw (int): Side length of the square kernel.
        use_float16 (bool): Run the convolution in half precision when the
            weight lives on a CUDA device. Defaults to True.
        device (torch.device): Device the parameters are created on.
            Defaults to ``torch.device('cpu')``.
        requires_grad (bool): Whether the parameters require gradients.
            Defaults to False.
    """

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 kw: int,
                 use_float16: bool = True,
                 device: torch.device = torch.device('cpu'),
                 requires_grad: bool = False) -> None:
        super().__init__()

        # Initialise plain tensors first: an in-place ``normal_`` on a leaf
        # tensor that already requires grad raises a RuntimeError.
        w = torch.empty((n_out, n_in, kw, kw),
                        dtype=torch.float32,
                        device=device)
        w.normal_(std=1 / math.sqrt(n_in * kw**2))
        b = torch.zeros((n_out, ), dtype=torch.float32, device=device)

        self.kw = kw
        # ``nn.Parameter`` defaults to ``requires_grad=True``; the flag has to
        # be passed explicitly, otherwise ``requires_grad=False`` is ignored.
        self.w = nn.Parameter(w, requires_grad=requires_grad)
        self.b = nn.Parameter(b, requires_grad=requires_grad)
        self.use_float16 = use_float16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution.

        Args:
            x (torch.Tensor): Input passed to ``F.conv2d``. It is cast to
                float16 (half-precision CUDA path) or float32 (otherwise)
                when its dtype differs.

        Returns:
            torch.Tensor: The convolved feature map.
        """
        if self.use_float16 and 'cuda' in self.w.device.type:
            # Half-precision path: only taken for CUDA-resident weights.
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
class EncoderBlock(nn.Module):
    """Residual bottleneck block rewritten after the DALL-E reference code.

    The block adds a shortcut branch to a scaled residual branch. The
    residual branch is four ReLU + convolution stages (kernel sizes 3, 3, 3
    and 1) with a hidden width of ``n_out // 4``. The shortcut is a 1x1
    convolution when the channel count changes and an identity otherwise.

    Args:
        n_in (int): Number of input channels.
        n_out (int): Number of output channels.
        n_layers (int): Total number of blocks in the encoder; the residual
            branch is scaled by ``1 / n_layers ** 2``.
        device (torch.device, optional): Device forwarded to every ``Conv2d``.
            Defaults to None.
        requires_grad (bool): Forwarded to every ``Conv2d``.
            Defaults to False.
    """

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 n_layers: int,
                 device: torch.device = None,
                 requires_grad: bool = False) -> None:
        super().__init__()
        self.n_hid = n_out // 4
        self.post_gain = 1 / (n_layers**2)

        def conv(c_in: int, c_out: int, kernel: int) -> Conv2d:
            return Conv2d(
                c_in,
                c_out,
                kernel,
                device=device,
                requires_grad=requires_grad)

        # The shortcut is created before the residual convolutions so that
        # parameters are drawn in the same order as before.
        if n_in != n_out:
            self.id_path = conv(n_in, n_out, 1)
        else:
            self.id_path = nn.Identity()

        hid = self.n_hid
        conv_specs = [
            (n_in, hid, 3),
            (hid, hid, 3),
            (hid, hid, 3),
            (hid, n_out, 1),
        ]
        stages = OrderedDict()
        for idx, (c_in, c_out, kernel) in enumerate(conv_specs, start=1):
            stages[f'relu_{idx}'] = nn.ReLU()
            stages[f'conv_{idx}'] = conv(c_in, c_out, kernel)
        self.res_path = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``id_path(x) + post_gain * res_path(x)``."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
@MODELS.register_module(name='DALL-E')
class DALLEEncoder(BaseModule):
    """DALL-E Encoder for feature extraction.

    The network is a 7x7 input convolution, four groups of ``EncoderBlock``
    (the first three followed by a 2x2 max pooling) and a ReLU + 1x1
    convolution output stage that maps to ``vocab_size`` channels.

    Args:
        group_count (int): Number of groups in DALL-E encoder. In this
            implementation it only enters the residual gain through
            ``group_count * n_blk_per_group``; four groups are always built.
            Defaults to 4.
        n_hid (int): Dimension of hidden layers. Defaults to 256.
        n_blk_per_group (int): Number of blocks per group. Defaults to 2.
        input_channels: (int): The channels of input images. Defaults to 3.
        vocab_size (int): Vocabulary size, indicating the number of classes.
            Defaults to 8192.
        device (torch.device): Device of parameters. Defaults to
            ``torch.device('cpu')``.
        requires_grad (bool): Require gradient or not. Defaults to False.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 group_count: int = 4,
                 n_hid: int = 256,
                 n_blk_per_group: int = 2,
                 input_channels: int = 3,
                 vocab_size: int = 8192,
                 device: torch.device = torch.device('cpu'),
                 requires_grad: bool = False,
                 init_cfg: Union[dict, List[dict], None] = None):
        super().__init__(init_cfg=init_cfg)
        self.input_channels = input_channels

        blk_range = range(n_blk_per_group)
        n_layers = group_count * n_blk_per_group
        make_conv = partial(Conv2d, device=device, requires_grad=requires_grad)
        make_blk = partial(
            EncoderBlock,
            n_layers=n_layers,
            device=device,
            requires_grad=requires_grad)

        def make_group(n_in: int, n_out: int,
                       with_pool: bool = True) -> nn.Sequential:
            """Build one group: only its first block changes the width."""
            members = OrderedDict()
            for i in blk_range:
                members[f'block_{i + 1}'] = make_blk(
                    n_in if i == 0 else n_out, n_out)
            if with_pool:
                members['pool'] = nn.MaxPool2d(kernel_size=2)
            return nn.Sequential(members)

        # Sub-modules are created strictly in network order so that module
        # names and the parameter creation order stay as they were.
        stages = OrderedDict()
        stages['input'] = make_conv(input_channels, 1 * n_hid, 7)
        stages['group_1'] = make_group(1 * n_hid, 1 * n_hid)
        stages['group_2'] = make_group(1 * n_hid, 2 * n_hid)
        stages['group_3'] = make_group(2 * n_hid, 4 * n_hid)
        # The last group keeps the resolution (no pooling).
        stages['group_4'] = make_group(4 * n_hid, 8 * n_hid, with_pool=False)
        stages['output'] = nn.Sequential(
            OrderedDict([
                ('relu', nn.ReLU()),
                ('conv',
                 make_conv(8 * n_hid, vocab_size, 1, use_float16=False)),
            ]))
        self.blocks = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of DALL-E encoder.

        Args:
            x (torch.Tensor): The input images with shape (B, C, H, W). The
                tensor is cast to float32 before it is processed.

        Returns:
            torch.Tensor: The output with shape (B, vocab_size, h, w).

        Raises:
            ValueError: If ``x`` is not 4-dimensional or its channel count
                differs from ``input_channels``.
        """
        # The cast guarantees float32, so no separate dtype check is needed.
        x = x.float()
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model '
                             f'built for {self.input_channels}')
        return self.blocks(x)
@MODELS.register_module()
class CAEPretrainViT(BEiTViT):
    """Vision Transformer for CAE pre-training and the implementation is based
    on BEiTViT.

    Every constructor argument is forwarded to ``BEiTViT`` (together with
    ``with_cls_token=True``). After construction the position embedding is
    excluded from gradient updates.

    Args:
        arch (str | dict): Vision Transformer architecture. Default: 'b'
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        in_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        bias (bool | str): The option to add learnable bias for q, k, v. If
            bias is True, it will add learnable bias. If bias is 'qv_bias',
            it will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Default to 'qv_bias'.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"raw"``.
        frozen_stages (int): Forwarded to ``BEiTViT``. Defaults to -1.
        use_abs_pos_emb (bool): Forwarded to ``BEiTViT``. Defaults to True.
        use_rel_pos_bias (bool): Forwarded to ``BEiTViT``. Note that the
            masked forward path of this class always calls the layers with
            ``rel_pos_bias=None``. Defaults to False.
        use_shared_rel_pos_bias (bool): Forwarded to ``BEiTViT``.
            Defaults to False.
        layer_scale_init_value (float, optional): The init value of gamma in
            BEiTTransformerEncoderLayer. Defaults to None.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict | list[dict], optional): Initialization config.
            Defaults to constant init for ``LayerNorm``, truncated normal
            init for ``Conv2d`` and uniform Xavier init for ``Linear``.
    """

    def __init__(
        self,
        arch: str = 'b',
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        out_indices: int = -1,
        drop_rate: float = 0,
        drop_path_rate: float = 0,
        bias: Union[bool, str] = 'qv_bias',
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        final_norm: bool = True,
        out_type: str = 'raw',
        frozen_stages: int = -1,
        use_abs_pos_emb: bool = True,
        use_rel_pos_bias: bool = False,
        use_shared_rel_pos_bias: bool = False,
        layer_scale_init_value: Optional[float] = None,
        interpolate_mode: str = 'bicubic',
        patch_cfg: dict = dict(),
        layer_cfgs: dict = dict(),
        init_cfg: Union[dict, List[dict], None] = [
            dict(type='Constant', val=1, layer=['LayerNorm']),
            dict(type='TruncNormal', std=0.02, layer=['Conv2d']),
            dict(type='Xavier', distribution='uniform', layer=['Linear'])
        ]
    ) -> None:
        # NOTE: the dict/list defaults above are shared objects; this class
        # only forwards them to the parent constructor and never mutates them.
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            bias=bias,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            with_cls_token=True,
            frozen_stages=frozen_stages,
            use_abs_pos_emb=use_abs_pos_emb,
            use_rel_pos_bias=use_rel_pos_bias,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
            layer_scale_init_value=layer_scale_init_value,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)
        # The position embedding is filled in ``init_weights`` and not trained.
        self.pos_embed.requires_grad = False
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and cls token."""
        super().init_weights()

        # With a 'Pretrained' init config the loaded weights are kept as-is.
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            return

        # initialize position embedding in backbone
        # NOTE(review): ``int(num_patches ** .5)`` assumes a square patch
        # grid -- confirm before using non-square inputs.
        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.pos_embed.shape[-1],
            cls_token=True)
        self.pos_embed.data.copy_(pos_embed.float())

        trunc_normal_(self.cls_token, std=.02)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate features for masked images.

        This function generates mask images and get the hidden features for
        visible patches.

        The function supports two kind of forward behaviors. If the ``mask`` is
        not ``None``, the forward function will be executed as masked image
        modeling pre-training; if the ``mask`` is ``None``, the forward
        function will call ``super().forward()``, which extract features from
        images without mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (torch.Tensor, optional): Mask for input, which is of shape
                B x L. Positions where the mask is set are dropped. The mask
                is inverted with ``~`` and used as an index, and the result
                is reshaped per sample, so every sample is expected to keep
                the same number of patches. Defaults to None.

        Returns:
            torch.Tensor: hidden features.
        """
        if mask is None:
            return super().forward(x)

        x, _ = self.patch_embed(x)
        batch_size, _, dim = x.size()
        visible = ~mask

        # NOTE: unmasked embeddings, with the class token prepended
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x_unmasked = x[visible].reshape(batch_size, -1, dim)
        x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)

        # Select the position embeddings of the visible patches; index 0 is
        # the class-token position and is always kept.
        pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1,
                                          dim)
        pos_embed_unmasked = pos_embed[:, 1:][visible].reshape(
            batch_size, -1, dim)
        pos_embed_unmasked = torch.cat(
            (pos_embed[:, :1], pos_embed_unmasked), dim=1)

        x_unmasked = self.drop_after_pos(x_unmasked + pos_embed_unmasked)

        last_idx = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x_unmasked = layer(x=x_unmasked, rel_pos_bias=None)

            # The final norm is applied once, after the last layer.
            if i == last_idx and self.final_norm:
                x_unmasked = self.norm1(x_unmasked)

        return x_unmasked
@MODELS.register_module()
class CAE(BaseSelfSupervisor):
    """CAE.

    Implementation of `Context Autoencoder for Self-Supervised Representation
    Learning <https://arxiv.org/abs/2202.03026>`_.

    Besides the backbone, a second network (the teacher) is built from the
    same backbone config. It is updated from the backbone by a momentum rule
    and provides the latent targets.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of neck.
        head (dict): Config dict for module of head functions.
        target_generator: (dict, optional): The target_generator module to
            generate targets for self-supervised learning optimization, such as
            HOG, extracted features from other modules(DALL-E, CLIP), etc.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.0.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 target_generator: Optional[dict] = None,
                 base_momentum: float = 0.0,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            target_generator=target_generator,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        self.momentum = base_momentum
        # The teacher shares the backbone's architecture, not its parameters.
        self.teacher = MODELS.build(backbone)

    def init_weights(self) -> None:
        """Initialize weights."""
        super().init_weights()

        # init the weights of teacher with those of backbone
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.data.copy_(param_backbone.data)
            # The teacher is never optimised directly.
            param_teacher.requires_grad = False

    def momentum_update(self) -> None:
        """Momentum update of the teacher network.

        Each teacher parameter becomes
        ``teacher * momentum + backbone * (1 - momentum)``.
        """
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + \
                param_backbone.data * (1. - self.momentum)

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features with the backbone, without any mask.

        Args:
            inputs (torch.Tensor): The input images.

        Returns:
            The output of ``self.backbone(inputs, mask=None)``.
        """
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                fed to the backbone and the teacher, ``inputs[1]`` to the
                target generator.
            data_samples (List[DataSample]): All elements required
                during the forward function. The ``mask`` attribute of each
                sample is read.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components with the
            keys ``loss`` (sum of both terms), ``main`` and ``align``.
        """
        mask = torch.stack([data_sample.mask for data_sample in data_samples])
        mask = mask.flatten(1).to(torch.bool)
        batch_size = inputs[0].shape[0]

        unmasked = self.backbone(inputs[0], mask)

        # get the latent prediction for the masked patches
        with torch.no_grad():
            # inputs[0] is the prediction image
            # The teacher receives the inverted mask, so it encodes exactly
            # the patches the backbone did not see.
            latent_target = self.teacher(inputs[0], ~mask)
            # Drop the class token.
            latent_target = latent_target[:, 1:, :]
            self.momentum_update()

        pos_embed = self.backbone.pos_embed.expand(batch_size, -1, -1)
        embed_dim = pos_embed.shape[-1]
        # Index 0 of the position embedding belongs to the class token.
        patch_pos_embed = pos_embed[:, 1:]
        pos_embed_masked = patch_pos_embed[mask].reshape(
            batch_size, -1, embed_dim)
        pos_embed_unmasked = patch_pos_embed[~mask].reshape(
            batch_size, -1, embed_dim)

        # input the unmasked tokens and masked tokens to the decoder
        logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
                                        pos_embed_unmasked)

        logits = logits.view(-1, logits.shape[-1])
        # inputs[1] is the target image
        logits_target = self.target_generator(inputs[1])
        loss_main, loss_align = self.head.loss(logits, logits_target,
                                               latent_pred, latent_target,
                                               mask)
        losses = dict()

        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses
|