# Copyright (c) OpenMMLab. All rights reserved.
# Part of code is modified from BEiT
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.models.backbones import BEiTViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


class Conv2d(nn.Module):
"""Rewrite Conv2d module according to DALL-E code."""

    def __init__(self,
n_in: int,
n_out: int,
kw: int,
use_float16: bool = True,
device: torch.device = torch.device('cpu'),
requires_grad: bool = False) -> None:
super().__init__()
w = torch.empty((n_out, n_in, kw, kw),
dtype=torch.float32,
device=device,
requires_grad=requires_grad)
w.normal_(std=1 / math.sqrt(n_in * kw**2))
b = torch.zeros((n_out, ),
dtype=torch.float32,
device=device,
requires_grad=requires_grad)
self.kw = kw
self.w, self.b = nn.Parameter(w), nn.Parameter(b)
self.use_float16 = use_float16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
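        # Follow the original DALL-E setup: run the convolution in float16 on
        # CUDA devices when enabled, and fall back to float32 elsewhere.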
if self.use_float16 and 'cuda' in self.w.device.type:
if x.dtype != torch.float16:
x = x.half()
w, b = self.w.half(), self.b.half()
else:
if x.dtype != torch.float32:
x = x.float()
w, b = self.w, self.b
return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)


class EncoderBlock(nn.Module):
"""Rewrite EncoderBlock module according to DALL-E code."""

    def __init__(self,
n_in: int,
n_out: int,
n_layers: int,
device: torch.device = None,
requires_grad: bool = False) -> None:
super().__init__()
self.n_hid = n_out // 4
self.post_gain = 1 / (n_layers**2)
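        # Scale the residual branch by 1 / n_layers**2 so that the output
        # variance stays controlled at initialization in deep stacks,
        # following the DALL-E encoder.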
make_conv = partial(Conv2d, device=device, requires_grad=requires_grad)
self.id_path = make_conv(n_in, n_out,
1) if n_in != n_out else nn.Identity()
self.res_path = nn.Sequential(
OrderedDict([
('relu_1', nn.ReLU()),
('conv_1', make_conv(n_in, self.n_hid, 3)),
('relu_2', nn.ReLU()),
('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
('relu_3', nn.ReLU()),
('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
('relu_4', nn.ReLU()),
('conv_4', make_conv(self.n_hid, n_out, 1)),
]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.id_path(x) + self.post_gain * self.res_path(x)


@MODELS.register_module(name='DALL-E')
class DALLEEncoder(BaseModule):
"""DALL-E Encoder for feature extraction.

    Args:
group_count (int): Number of groups in DALL-E encoder. Defaults to 4.
n_hid (int): Dimension of hidden layers. Defaults to 256.
n_blk_per_group (int): Number of blocks per group. Defaults to 2.
        input_channels (int): The channels of input images. Defaults to 3.
vocab_size (int): Vocabulary size, indicating the number of classes.
Defaults to 8192.
device (torch.device): Device of parameters. Defaults to
``torch.device('cpu')``.
requires_grad (bool): Require gradient or not. Defaults to False.
init_cfg (Union[List[dict], dict], optional): Config dict for weight
initialization. Defaults to None.
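
    Example:
        A minimal shape sketch under the defaults above (illustrative only).
        The three max-pooling stages downsample by a factor of 8, so a
        112x112 input yields a 14x14 grid of token logits:

        >>> import torch
        >>> encoder = DALLEEncoder()
        >>> logits = encoder(torch.rand(1, 3, 112, 112))
        >>> logits.shape
        torch.Size([1, 8192, 14, 14])
        >>> token_ids = logits.argmax(dim=1)  # discrete visual tokens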
"""

    def __init__(self,
group_count: int = 4,
n_hid: int = 256,
n_blk_per_group: int = 2,
input_channels: int = 3,
vocab_size: int = 8192,
device: torch.device = torch.device('cpu'),
requires_grad: bool = False,
init_cfg: Union[dict, List[dict], None] = None):
super().__init__(init_cfg=init_cfg)
self.input_channels = input_channels
blk_range = range(n_blk_per_group)
n_layers = group_count * n_blk_per_group
make_conv = partial(Conv2d, device=device, requires_grad=requires_grad)
make_blk = partial(
EncoderBlock,
n_layers=n_layers,
device=device,
requires_grad=requires_grad)
self.blocks = nn.Sequential(
OrderedDict([
('input', make_conv(input_channels, 1 * n_hid, 7)),
('group_1',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}', make_blk(1 * n_hid, 1 * n_hid))
for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_2',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(1 * n_hid if i == 0 else 2 * n_hid,
2 * n_hid)) for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_3',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(2 * n_hid if i == 0 else 4 * n_hid,
4 * n_hid)) for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_4',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(4 * n_hid if i == 0 else 8 * n_hid,
8 * n_hid)) for i in blk_range],
]))),
('output',
nn.Sequential(
OrderedDict([
('relu', nn.ReLU()),
('conv',
make_conv(
8 * n_hid, vocab_size, 1, use_float16=False)),
]))),
]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function of DALL-E encoder.

        Args:
x (torch.Tensor): The input images with shape (B, C, H, W).

        Returns:
torch.Tensor: The output with shape (B, vocab_size, h, w).
"""
x = x.float()
if len(x.shape) != 4:
raise ValueError(f'input shape {x.shape} is not 4d')
if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model '
                             f'was built for {self.input_channels}')
if x.dtype != torch.float32:
raise ValueError('input must have dtype torch.float32')
return self.blocks(x)


@MODELS.register_module()
class CAEPretrainViT(BEiTViT):
"""Vision Transformer for CAE pre-training and the implementation is based
on BEiTViT.
Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        bias (bool | str): The option to add learnable bias for q, k, v. If
            bias is True, it will add learnable bias. If bias is 'qv_bias',
            it will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Defaults to 'qv_bias'.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
            It only works without input mask. Defaults to ``"raw"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float, optional): The init value of gamma in
            BEiTTransformerEncoderLayer. Defaults to None.
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
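
    Example:
        A minimal sketch of the masked forward path (the mask layout below is
        an assumption for illustration; for ``arch='b'`` with 224x224 inputs
        there are 196 patches and the embed dim is 768):

        >>> import torch
        >>> vit = CAEPretrainViT(arch='b', img_size=224, patch_size=16)
        >>> mask = torch.zeros(1, 196, dtype=torch.bool)
        >>> mask[:, :98] = True  # hide half of the patches
        >>> feats = vit(torch.rand(1, 3, 224, 224), mask)
        >>> feats.shape  # cls token + 98 visible patch tokens
        torch.Size([1, 99, 768])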
"""

    def __init__(
self,
arch: str = 'b',
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
out_indices: int = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
        bias: Union[bool, str] = 'qv_bias',
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'raw',
frozen_stages: int = -1,
use_abs_pos_emb: bool = True,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = False,
        layer_scale_init_value: Optional[float] = None,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
        init_cfg: Union[dict, List[dict]] = [
dict(type='Constant', val=1, layer=['LayerNorm']),
dict(type='TruncNormal', std=0.02, layer=['Conv2d']),
dict(type='Xavier', distribution='uniform', layer=['Linear'])
]
) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
bias=bias,
norm_cfg=norm_cfg,
final_norm=final_norm,
out_type=out_type,
with_cls_token=True,
frozen_stages=frozen_stages,
use_abs_pos_emb=use_abs_pos_emb,
use_rel_pos_bias=use_rel_pos_bias,
use_shared_rel_pos_bias=use_shared_rel_pos_bias,
layer_scale_init_value=layer_scale_init_value,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
self.pos_embed.requires_grad = False
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]

    def init_weights(self) -> None:
"""Initialize position embedding, patch embedding and cls token."""
super().init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# initialize position embedding in backbone
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.pos_embed.shape[-1],
cls_token=True)
self.pos_embed.data.copy_(pos_embed.float())
trunc_normal_(self.cls_token, std=.02)

    def forward(self, x: torch.Tensor,
mask: Optional[torch.Tensor]) -> torch.Tensor:
"""Generate features for masked images.

        This function generates masked images and gets the hidden features
        of the visible patches.

        The function supports two kinds of forward behaviors. If the ``mask``
        is not ``None``, the forward function will be executed as masked
        image modeling pre-training; if the ``mask`` is ``None``, the forward
        function will call ``super().forward()``, which extracts features
        from images without a mask.

        Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
mask (torch.Tensor, optional): Mask for input, which is of shape
B x L.

        Returns:
            torch.Tensor: The hidden features.
"""
if mask is None:
return super().forward(x)
else:
x, _ = self.patch_embed(x)
batch_size, _, dim = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# NOTE: unmasked embeddings
x_unmasked = x[~mask].reshape(batch_size, -1, dim)
x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)
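            # Gather the position embeddings of the visible patches only,
            # keeping the cls token embedding at index 0.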
pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1,
dim)
pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
batch_size, -1, dim)
pos_embed_unmasked = torch.cat(
(pos_embed[:, :1], pos_embed_unmasked), dim=1)
x_unmasked = x_unmasked + pos_embed_unmasked
x_unmasked = self.drop_after_pos(x_unmasked)
for i, layer in enumerate(self.layers):
x_unmasked = layer(x=x_unmasked, rel_pos_bias=None)
if i == len(self.layers) - 1 and self.final_norm:
x_unmasked = self.norm1(x_unmasked)
return x_unmasked


@MODELS.register_module()
class CAE(BaseSelfSupervisor):
"""CAE.
Implementation of `Context Autoencoder for Self-Supervised Representation
Learning <https://arxiv.org/abs/2202.03026>`_.
Args:
backbone (dict): Config dict for module of backbone.
neck (dict): Config dict for module of neck.
head (dict): Config dict for module of head functions.
        target_generator (dict, optional): The target_generator module to
            generate targets for self-supervised learning optimization, such
            as HOG, extracted features from other modules (DALL-E, CLIP),
            etc. Defaults to None.
base_momentum (float): The base momentum coefficient for the target
network. Defaults to 0.0.
data_preprocessor (dict, optional): The config for preprocessing
            input data. If None, or if no type is specified, it will use
            "SelfSupDataPreprocessor" as the type.
See :class:`SelfSupDataPreprocessor` for more details.
Defaults to None.
init_cfg (Union[List[dict], dict], optional): Config dict for weight
initialization. Defaults to None.
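
    Example:
        An illustrative config sketch. The ``CAENeck`` and ``CAEHead`` type
        names and their required arguments are assumptions here; consult the
        mmpretrain CAE configs for the exact keys:

        >>> model = CAE(
        ...     backbone=dict(type='CAEPretrainViT', arch='b'),
        ...     neck=dict(type='CAENeck'),
        ...     head=dict(type='CAEHead'),
        ...     target_generator=dict(type='DALL-E'))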
"""

    def __init__(self,
backbone: dict,
neck: dict,
head: dict,
target_generator: Optional[dict] = None,
base_momentum: float = 0.0,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
head=head,
target_generator=target_generator,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
self.momentum = base_momentum
self.teacher = MODELS.build(backbone)

    def init_weights(self) -> None:
"""Initialize weights."""
super().init_weights()
        # initialize the weights of the teacher with those of the backbone
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.detach_()
            param_teacher.data.copy_(param_backbone.data)
            param_teacher.requires_grad = False

    def momentum_update(self) -> None:
"""Momentum update of the teacher network."""
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + \
                param_backbone.data * (1. - self.momentum)

    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        """Extract features from the input images without masking."""
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.

        Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.

        Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
mask = torch.stack([data_sample.mask for data_sample in data_samples])
mask = mask.flatten(1).to(torch.bool)
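        # mask is (B, L) with True marking masked patches: the student
        # backbone encodes the visible patches, while the teacher (below)
        # encodes the masked ones to produce the latent targets.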
unmasked = self.backbone(inputs[0], mask)
# get the latent prediction for the masked patches
with torch.no_grad():
# inputs[0] is the prediction image
latent_target = self.teacher(inputs[0], ~mask)
latent_target = latent_target[:, 1:, :]
self.momentum_update()
pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1)
        pos_embed_masked = pos_embed[:, 1:][mask].reshape(
            inputs[0].shape[0], -1, pos_embed.shape[-1])
pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
inputs[0].shape[0], -1, pos_embed.shape[-1])
# input the unmasked tokens and masked tokens to the decoder
logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
pos_embed_unmasked)
logits = logits.view(-1, logits.shape[-1])
# inputs[1] is the target image
logits_target = self.target_generator(inputs[1])
loss_main, loss_align = self.head.loss(logits, logits_target,
latent_pred, latent_target,
mask)
losses = dict()
losses['loss'] = loss_main + loss_align
losses['main'] = loss_main
losses['align'] = loss_align
return losses